Skip to main content

mistralrs_core/pipeline/
paths.rs

1use std::{
2    collections::HashMap,
3    fs,
4    path::{Path, PathBuf},
5};
6
7use anyhow::Result;
8use either::Either;
9use hf_hub::{
10    api::sync::{ApiBuilder, ApiRepo},
11    Repo, RepoType,
12};
13use regex_automata::meta::Regex;
14use serde_json::Value;
15use tracing::{info, warn};
16
17use crate::{
18    api_dir_list, api_get_file,
19    lora::LoraConfig,
20    pipeline::{
21        chat_template::{ChatTemplate, ChatTemplateValue},
22        isq::UQFF_RESIDUAL_SAFETENSORS,
23    },
24    utils::tokens::get_token,
25    xlora_models::XLoraConfig,
26    ModelPaths, Ordering, TokenSource, GLOBAL_HF_CACHE,
27};
28
// Match files against these, avoids situations like `consolidated.safetensors`

/// Sharded safetensors weights, e.g. `model-00001-of-00002.safetensors`.
const SAFETENSOR_MATCH: &str = r"model-\d+-of-\d+\.safetensors\b";
/// Single-file safetensors weights, i.e. `model.safetensors`.
const QUANT_SAFETENSOR_MATCH: &str = r"model\.safetensors\b";
/// Sharded pickle weights, e.g. `pytorch_model-00001-of-00002.bin` (also `.pth` / `.pt`).
///
/// The dot before the extension is escaped so that it only matches a literal `.`
/// rather than any character.
const PICKLE_MATCH: &str = r"pytorch_model-\d{5}-of-\d{5}\.((pth)|(pt)|(bin))\b";
33
34#[derive(Clone, Debug)]
35pub struct LoraAdapterPaths {
36    pub lora_config: mistralrs_quant::LoraConfig,
37    pub adapter_path: PathBuf,
38}
39
40#[allow(clippy::large_enum_variant)]
41#[derive(Clone, Debug)]
42pub enum AdapterPaths {
43    XLora {
44        adapter_configs: Option<Vec<((String, String), LoraConfig)>>,
45        adapter_safetensors: Option<Vec<(String, PathBuf)>>,
46        classifier_path: Option<PathBuf>,
47        xlora_order: Option<Ordering>,
48        xlora_config: Option<XLoraConfig>,
49        lora_preload_adapter_info: Option<HashMap<String, (PathBuf, LoraConfig)>>,
50    },
51    Lora(Vec<LoraAdapterPaths>),
52    None,
53}
54
55pub fn get_xlora_paths(
56    base_model_id: String,
57    xlora_model_id: Option<&String>,
58    lora_adapter_ids: Option<&Vec<String>>,
59    token_source: &TokenSource,
60    revision: String,
61    xlora_order: Option<&Ordering>,
62) -> Result<AdapterPaths> {
63    match (lora_adapter_ids, xlora_model_id, xlora_order) {
64        (None, Some(xlora_id), Some(xlora_order)) => {
65            let api = {
66                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
67                let mut api = ApiBuilder::from_cache(cache)
68                    .with_progress(true)
69                    .with_token(get_token(token_source)?);
70                if let Some(cache_dir) = crate::hf_hub_cache_dir() {
71                    api = api.with_cache_dir(cache_dir);
72                }
73                api.build().map_err(candle_core::Error::msg)?
74            };
75            let api = api.repo(Repo::with_revision(
76                xlora_id.clone(),
77                RepoType::Model,
78                revision,
79            ));
80            let model_id = Path::new(&xlora_id);
81            let dir_list = api_dir_list!(api, model_id, true).collect::<Vec<_>>();
82            // Get the path for the xlora classifier
83            let xlora_classifier = &dir_list
84                .clone()
85                .into_iter()
86                .filter(|x| x.contains("xlora_classifier.safetensors"))
87                .collect::<Vec<_>>();
88            if xlora_classifier.len() > 1 {
89                warn!("Detected multiple X-LoRA classifiers: {xlora_classifier:?}");
90                warn!("Selected classifier: `{}`", &xlora_classifier[0]);
91            }
92            let xlora_classifier = xlora_classifier.first();
93
94            let classifier_path = xlora_classifier
95                .map(|xlora_classifier| -> candle_core::Result<_> {
96                    Ok(api_get_file!(api, xlora_classifier, model_id))
97                })
98                .transpose()?;
99
100            // Get the path for the xlora config by checking all for valid versions.
101            // NOTE(EricLBuehler): Remove this functionality because all configs should be deserializable
102            let xlora_configs = &dir_list
103                .clone()
104                .into_iter()
105                .filter(|x| x.contains("xlora_config.json"))
106                .collect::<Vec<_>>();
107            if xlora_configs.len() > 1 {
108                warn!("Detected multiple X-LoRA configs: {xlora_configs:?}");
109            }
110
111            let mut xlora_config: Option<XLoraConfig> = None;
112            let mut last_err: Option<serde_json::Error> = None;
113            for (i, config_path) in xlora_configs.iter().enumerate() {
114                if xlora_configs.len() != 1 {
115                    warn!("Selecting config: `{}`", config_path);
116                }
117                let config_path = api_get_file!(api, config_path, model_id);
118                let conf = fs::read_to_string(config_path)?;
119                let deser: Result<XLoraConfig, serde_json::Error> = serde_json::from_str(&conf);
120                match deser {
121                    Ok(conf) => {
122                        xlora_config = Some(conf);
123                        break;
124                    }
125                    Err(e) => {
126                        if i != xlora_configs.len() - 1 {
127                            warn!("Config is broken with error `{e}`");
128                        }
129                        last_err = Some(e);
130                    }
131                }
132            }
133            let xlora_config = xlora_config.map(Some).unwrap_or_else(|| {
134                if let Some(last_err) = last_err {
135                    panic!("Unable to derserialize any configs. Last error: {last_err}")
136                } else {
137                    None
138                }
139            });
140
141            // If there are adapters in the ordering file, get their names and remote paths
142            let adapter_files = dir_list
143                .into_iter()
144                .filter_map(|name| {
145                    if let Some(ref adapters) = xlora_order.adapters {
146                        for adapter_name in adapters {
147                            if name.contains(adapter_name) {
148                                return Some((name, adapter_name.clone()));
149                            }
150                        }
151                    }
152                    None
153                })
154                .collect::<Vec<_>>();
155            if adapter_files.is_empty() && xlora_order.adapters.is_some() {
156                anyhow::bail!("Adapter files are empty. Perhaps the ordering file adapters does not match the actual adapters?")
157            }
158
159            // Get the local paths for each adapter
160            let mut adapters_paths: HashMap<String, Vec<PathBuf>> = HashMap::new();
161            for (file, name) in adapter_files {
162                if let Some(paths) = adapters_paths.get_mut(&name) {
163                    paths.push(api_get_file!(api, &file, model_id));
164                } else {
165                    adapters_paths.insert(name, vec![api_get_file!(api, &file, model_id)]);
166                }
167            }
168
169            // Sort local paths for the adapter configs and safetensors files
170            let mut adapters_configs = Vec::new();
171            let mut adapters_safetensors = Vec::new();
172            if let Some(ref adapters) = xlora_order.adapters {
173                for (i, name) in adapters.iter().enumerate() {
174                    let paths = adapters_paths
175                        .get(name)
176                        .unwrap_or_else(|| panic!("Adapter {name} not found."));
177                    for path in paths {
178                        if path.extension().unwrap() == "safetensors" {
179                            adapters_safetensors.push((name.clone(), path.to_owned()));
180                        } else {
181                            let conf = fs::read_to_string(path)?;
182                            let lora_config: LoraConfig = serde_json::from_str(&conf)?;
183                            adapters_configs
184                                .push((((i + 1).to_string(), name.clone()), lora_config));
185                        }
186                    }
187                }
188            }
189
190            // Make sure they all match
191            if xlora_order.base_model_id
192                != *xlora_config
193                    .as_ref()
194                    .map(|cfg| &cfg.base_model_id)
195                    .unwrap_or(&base_model_id)
196                || xlora_config
197                    .as_ref()
198                    .map(|cfg| &cfg.base_model_id)
199                    .unwrap_or(&base_model_id)
200                    != &base_model_id
201            {
202                anyhow::bail!(
203                    "Adapter ordering file, adapter model config, and base model ID do not match: {}, {}, and {} respectively.",
204                    xlora_order.base_model_id,
205                    xlora_config.map(|cfg| cfg.base_model_id).unwrap_or(base_model_id.clone()),
206                    base_model_id
207                );
208            }
209
210            let lora_preload_adapter_info =
211                // If preload adapters are specified, get their metadata like above
212                if let Some(preload_adapters) = &xlora_order.preload_adapters {
213                    let mut output = HashMap::new();
214                    for adapter in preload_adapters {
215                        // Get the names and remote paths of the files associated with this adapter
216                        let adapter_files = api_dir_list!(api, &adapter.adapter_model_id, true)
217                            .filter_map(|f| {
218                                if f.contains(&adapter.name) {
219                                    Some((f, adapter.name.clone()))
220                                } else {
221                                    None
222                                }
223                            })
224                            .collect::<Vec<_>>();
225                        if adapter_files.is_empty() {
226                            anyhow::bail!("Adapter files are empty. Perhaps the ordering file adapters does not match the actual adapters?")
227                        }
228                        // Get local paths for this adapter
229                        let mut adapters_paths: HashMap<String, Vec<PathBuf>> = HashMap::new();
230                        for (file, name) in adapter_files {
231                            if let Some(paths) = adapters_paths.get_mut(&name) {
232                                paths.push(api_get_file!(api, &file, model_id));
233                            } else {
234                                adapters_paths
235                                    .insert(name, vec![api_get_file!(api, &file, model_id)]);
236                            }
237                        }
238
239                        let mut config = None;
240                        let mut safetensor = None;
241
242                        // Sort local paths for the adapter configs and safetensors files
243                        let paths = adapters_paths
244                            .get(&adapter.name)
245                            .unwrap_or_else(|| panic!("Adapter {} not found.", adapter.name));
246                        for path in paths {
247                            if path.extension().unwrap() == "safetensors" {
248                                safetensor = Some(path.to_owned());
249                            } else {
250                                let conf = fs::read_to_string(path)?;
251                                let lora_config: LoraConfig = serde_json::from_str(&conf)?;
252                                config = Some(lora_config);
253                            }
254                        }
255
256                        let (config, safetensor) = (config.unwrap(), safetensor.unwrap());
257                        output.insert(adapter.name.clone(), (safetensor, config));
258                    }
259                    Some(output)
260                } else {
261                    None
262                };
263
264            Ok(AdapterPaths::XLora {
265                adapter_configs: Some(adapters_configs),
266                adapter_safetensors: Some(adapters_safetensors),
267                classifier_path,
268                xlora_order: Some(xlora_order.clone()),
269                xlora_config,
270                lora_preload_adapter_info,
271            })
272        }
273        (Some(adapter_ids), None, None) => {
274            let mut lora_adapter_paths = Vec::new();
275            for adapter_id in adapter_ids {
276                info!("Loading adapter at `{adapter_id}`");
277
278                let api = {
279                    let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
280                    let mut api = ApiBuilder::from_cache(cache)
281                        .with_progress(true)
282                        .with_token(get_token(token_source)?);
283                    if let Some(cache_dir) = crate::hf_hub_cache_dir() {
284                        api = api.with_cache_dir(cache_dir);
285                    }
286                    api.build().map_err(candle_core::Error::msg)?
287                };
288                let api = api.repo(Repo::with_revision(
289                    adapter_id.clone(),
290                    RepoType::Model,
291                    revision.clone(),
292                ));
293
294                let config_path = api.get("adapter_config.json")?;
295                let adapter_path = api.get("adapter_model.safetensors")?;
296                let lora_config: mistralrs_quant::LoraConfig =
297                    serde_json::from_str(&fs::read_to_string(config_path)?)?;
298
299                lora_adapter_paths.push(LoraAdapterPaths {
300                    lora_config,
301                    adapter_path,
302                });
303            }
304
305            Ok(AdapterPaths::Lora(lora_adapter_paths))
306        }
307        (None, None, None) => Ok(AdapterPaths::None),
308        _ => anyhow::bail!(
309            "Incorrect configuration for an adapter model. Lora and XLora are mutually exclusive."
310        ),
311    }
312}
313
314pub fn get_model_paths(
315    revision: String,
316    token_source: &TokenSource,
317    quantized_model_id: Option<&String>,
318    quantized_filename: Option<&Vec<String>>,
319    api: &ApiRepo,
320    model_id: &Path,
321    loading_from_uqff: bool,
322) -> Result<Vec<PathBuf>> {
323    match quantized_filename {
324        Some(names) => {
325            let id = quantized_model_id.unwrap();
326            let mut files = Vec::new();
327
328            for name in names {
329                let qapi = {
330                    let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
331                    let mut api = ApiBuilder::from_cache(cache)
332                        .with_progress(true)
333                        .with_token(get_token(token_source)?);
334                    if let Some(cache_dir) = crate::hf_hub_cache_dir() {
335                        api = api.with_cache_dir(cache_dir);
336                    }
337                    api.build().map_err(candle_core::Error::msg)?
338                };
339                let qapi = qapi.repo(Repo::with_revision(
340                    id.to_string(),
341                    RepoType::Model,
342                    revision.clone(),
343                ));
344                let model_id = Path::new(&id);
345                files.push(api_get_file!(qapi, name, model_id));
346            }
347            Ok(files)
348        }
349        None => {
350            // We only match these patterns for model names
351            let safetensor_match = Regex::new(SAFETENSOR_MATCH)?;
352            let quant_safetensor_match = Regex::new(QUANT_SAFETENSOR_MATCH)?;
353            let pickle_match = Regex::new(PICKLE_MATCH)?;
354
355            let mut filenames = vec![];
356            let listing = api_dir_list!(api, model_id, true).filter(|x| {
357                safetensor_match.is_match(x)
358                    || pickle_match.is_match(x)
359                    || quant_safetensor_match.is_match(x)
360                    || x == UQFF_RESIDUAL_SAFETENSORS
361            });
362            let safetensors = listing
363                .clone()
364                .filter(|x| x.ends_with(".safetensors"))
365                .collect::<Vec<_>>();
366            let pickles = listing
367                .clone()
368                .filter(|x| x.ends_with(".pth") || x.ends_with(".pt") || x.ends_with(".bin"))
369                .collect::<Vec<_>>();
370            let uqff_residual = listing
371                .clone()
372                .filter(|x| x == UQFF_RESIDUAL_SAFETENSORS)
373                .collect::<Vec<_>>();
374            let files = if !safetensors.is_empty() {
375                // Always prefer safetensors
376                safetensors
377            } else if !pickles.is_empty() {
378                // Fall back to pickle
379                pickles
380            } else if !uqff_residual.is_empty() && loading_from_uqff {
381                uqff_residual
382            } else {
383                anyhow::bail!("Expected file with extension one of .safetensors, .pth, .pt, .bin.");
384            };
385            info!(
386                "Found model weight filenames {:?}",
387                files
388                    .iter()
389                    .map(|x| x.split('/').next_back().unwrap())
390                    .collect::<Vec<_>>()
391            );
392            for rfilename in files {
393                filenames.push(api_get_file!(api, &rfilename, model_id));
394            }
395            Ok(filenames)
396        }
397    }
398}
399
400/// Find and parse the appropriate [`ChatTemplate`], and ensure is has a valid [`ChatTemplate.chat_template`].
401/// If the provided `tokenizer_config.json` from [`ModelPaths.get_template_filename`] does not
402/// have a `chat_template`, use the provided one.
403///
404/// - Uses `chat_template_fallback` if `paths` does not contain a chat template file. This may be a literal or .json file.
405/// - `chat_template_ovrd` (GGUF chat template content) causes the usage of that string chat template initially.
406///   Falls back to `chat_template_file` if it is invalid. *The user must add the bos/unk/eos tokens manually if this
407///   is used.*
408///
409/// THE FOLLOWING IS IGNORED:
410/// After this, if the `chat_template_explicit` filename is specified (a json with one field: "chat_template" OR a jinja file),
411///  the chat template is overwritten with this chat template.
412#[allow(clippy::borrowed_box)]
413pub(crate) fn get_chat_template(
414    paths: &Box<dyn ModelPaths>,
415    jinja_explicit: Option<&String>,
416    chat_template_explicit: Option<&String>,
417    chat_template_fallback: Option<&String>,
418    chat_template_ovrd: Option<String>,
419) -> ChatTemplate {
420    // Get template content, this may be overridden.
421    let template_content = if let Some(template_filename) = paths.get_template_filename() {
422        if !["jinja", "json"].contains(
423            &template_filename
424                .extension()
425                .expect("Template filename must be a file")
426                .to_string_lossy()
427                .to_string()
428                .as_str(),
429        ) {
430            panic!("Template filename {template_filename:?} must end with `.json` or `.jinja`.");
431        }
432        Some(fs::read_to_string(template_filename).expect("Loading chat template failed."))
433    } else if chat_template_fallback.is_some_and(|f| f.ends_with(".json")) {
434        // User specified a file
435        let template_filename = chat_template_fallback
436            .expect("A tokenizer config or chat template file path must be specified.");
437        Some(fs::read_to_string(template_filename).expect("Loading chat template failed."))
438    } else if chat_template_ovrd.is_some() {
439        None
440    } else {
441        panic!("Expected chat template file to end with .json, or you can specify a tokenizer model ID to load the chat template there. If you are running a GGUF model, it probably does not contain a chat template.");
442    };
443    let mut template: ChatTemplate = match chat_template_ovrd {
444        Some(chat_template) => {
445            // In this case the override chat template is being used. The user must add the bos/eos/unk toks themselves.
446            info!("Using literal chat template.");
447            let mut template = ChatTemplate::default();
448            template.chat_template = Some(ChatTemplateValue(Either::Left(chat_template)));
449            template
450        }
451        None => {
452            // Check if template_filename is a .jinja file
453            if let Some(template_filename) = paths.get_template_filename() {
454                if template_filename.extension().map(|e| e.to_str()) == Some(Some("jinja")) {
455                    info!("Using chat template from .jinja file.");
456                    let mut template = ChatTemplate::default();
457                    template.chat_template = Some(ChatTemplateValue(Either::Left(
458                        template_content.as_ref().unwrap().clone(),
459                    )));
460                    template
461                } else {
462                    serde_json::from_str(&template_content.as_ref().unwrap().clone()).unwrap()
463                }
464            } else {
465                serde_json::from_str(&template_content.as_ref().unwrap().clone()).unwrap()
466            }
467        }
468    };
469    // Overwrite to use any present `chat_template.json`, only if there is not one present already.
470    if template.chat_template.is_none() {
471        if let Some(chat_template_explicit) = chat_template_explicit {
472            let ct =
473                fs::read_to_string(chat_template_explicit).expect("Loading chat template failed.");
474
475            let new_chat_template = if chat_template_explicit.ends_with(".jinja") {
476                ct
477            } else {
478                #[derive(Debug, serde::Deserialize)]
479                struct AutomaticTemplate {
480                    chat_template: String,
481                }
482                let deser: AutomaticTemplate = serde_json::from_str(&ct).unwrap();
483                deser.chat_template
484            };
485
486            template.chat_template = Some(ChatTemplateValue(Either::Left(new_chat_template)));
487        }
488    }
489
490    // JINJA explicit
491    if let Some(jinja_explicit) = jinja_explicit {
492        if !jinja_explicit.ends_with(".jinja") {
493            panic!("jinja_explicit must end with .jinja!");
494        }
495
496        let ct = fs::read_to_string(jinja_explicit).expect("Loading chat template failed.");
497
498        template.chat_template = Some(ChatTemplateValue(Either::Left(ct)));
499    }
500
501    let processor_conf: Option<crate::vision_models::processor_config::ProcessorConfig> = paths
502        .get_processor_config()
503        .as_ref()
504        .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
505    if let Some(processor_conf) = processor_conf {
506        if processor_conf.chat_template.is_some() {
507            template.chat_template = processor_conf
508                .chat_template
509                .map(|x| ChatTemplateValue(Either::Left(x)));
510        }
511    }
512
513    #[derive(Debug, serde::Deserialize)]
514    struct SpecifiedTemplate {
515        chat_template: String,
516        bos_token: Option<String>,
517        eos_token: Option<String>,
518        unk_token: Option<String>,
519    }
520
521    if template.chat_template.is_some() {
522        return template;
523    };
524
525    match &template.chat_template {
526        Some(_) => template,
527        None => {
528            info!("`tokenizer_config.json` does not contain a chat template, attempting to use specified JINJA chat template.");
529            let mut deser: HashMap<String, Value> =
530                serde_json::from_str(&template_content.unwrap()).unwrap();
531
532            match chat_template_fallback.cloned() {
533                Some(t) => {
534                    info!("Loading specified loading chat template file at `{t}`.");
535                    let templ: SpecifiedTemplate =
536                        serde_json::from_str(&fs::read_to_string(t.clone()).unwrap()).unwrap();
537                    deser.insert(
538                        "chat_template".to_string(),
539                        Value::String(templ.chat_template),
540                    );
541                    if let Some(bos_token) = templ.bos_token {
542                        deser.insert("bos_token".to_string(), Value::String(bos_token));
543                    }
544                    if let Some(eos_token) = templ.eos_token {
545                        deser.insert("eos_token".to_string(), Value::String(eos_token));
546                    }
547                    if let Some(unk_token) = templ.unk_token {
548                        deser.insert("unk_token".to_string(), Value::String(unk_token));
549                    }
550                }
551                None => {
552                    warn!("No specified chat template. No chat template will be used. Only prompts will be accepted, not messages.");
553                    deser.insert("chat_template".to_string(), Value::Null);
554                }
555            }
556
557            let ser = serde_json::to_string_pretty(&deser)
558                .expect("Serialization of modified chat template failed.");
559            serde_json::from_str(&ser).unwrap()
560        }
561    }
562}
563
mod tests {
    /// Sharded safetensors names match; malformed shard numbers and
    /// `consolidated.safetensors` do not.
    #[test]
    fn match_safetensors() -> anyhow::Result<()> {
        use super::SAFETENSOR_MATCH;
        use regex_automata::meta::Regex;

        let matcher = Regex::new(SAFETENSOR_MATCH)?;

        // model-00001-of-00001.safetensors ..= model-00006-of-00006.safetensors
        let accepted = (1..=6)
            .map(|shard| format!("model-{shard:05}-of-{shard:05}.safetensors"))
            .collect::<Vec<_>>();
        let rejected = [
            "model-0000a-of-00002.safetensors",
            "consolidated.safetensors",
        ];

        for name in &accepted {
            assert!(matcher.is_match(name.as_str()));
        }
        for name in rejected {
            assert!(!matcher.is_match(name));
        }
        Ok(())
    }

    /// Sharded pickle names match only with exactly five-digit, numeric shard fields.
    #[test]
    fn match_pickle() -> anyhow::Result<()> {
        use super::PICKLE_MATCH;
        use regex_automata::meta::Regex;

        let matcher = Regex::new(PICKLE_MATCH)?;

        let accepted = [
            "pytorch_model-00001-of-00002.bin",
            "pytorch_model-00002-of-00002.bin",
        ];
        let rejected = [
            "pytorch_model-000001-of-00001.bin",
            "pytorch_model-0000a-of-00002.bin",
            "pytorch_model-000-of-00003.bin",
            "pytorch_consolidated.bin",
        ];

        for name in accepted {
            assert!(matcher.is_match(name));
        }
        for name in rejected {
            assert!(!matcher.is_match(name));
        }
        Ok(())
    }
}