use std::{
    collections::HashMap,
    fs,
    path::{Path, PathBuf},
};

use anyhow::Result;
use either::Either;
use hf_hub::{
    api::sync::{ApiBuilder, ApiRepo},
    Repo, RepoType,
};
use regex_automata::meta::Regex;
use serde_json::Value;
use tracing::{info, warn};

use crate::{
    api_dir_list, api_get_file,
    lora::LoraConfig,
    pipeline::{
        chat_template::{ChatTemplate, ChatTemplateValue},
        isq::UQFF_RESIDUAL_SAFETENSORS,
    },
    utils::tokens::get_token,
    xlora_models::XLoraConfig,
    ModelPaths, Ordering, TokenSource, GLOBAL_HF_CACHE,
};

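// Filename patterns for recognizing weight files in a repository listing: sharded
// safetensors (`model-*-of-*.safetensors`), a single quantized `model.safetensors`,
// and sharded PyTorch pickle checkpoints (`pytorch_model-*-of-*.{pth,pt,bin}`).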
const SAFETENSOR_MATCH: &str = r"model-\d+-of-\d+\.safetensors\b";
const QUANT_SAFETENSOR_MATCH: &str = r"model\.safetensors\b";
const PICKLE_MATCH: &str = r"pytorch_model-\d{5}-of-\d{5}\.((pth)|(pt)|(bin))\b";

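/// Paths for a single LoRA adapter: the parsed `adapter_config.json` and the downloaded
/// `adapter_model.safetensors`.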
#[derive(Clone, Debug)]
pub struct LoraAdapterPaths {
    pub lora_config: mistralrs_quant::LoraConfig,
    pub adapter_path: PathBuf,
}

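/// Resolved adapter artifacts for a model: X-LoRA (classifier, ordering, per-adapter
/// configs and safetensors), plain LoRA adapters, or no adapters at all.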
#[allow(clippy::large_enum_variant)]
#[derive(Clone, Debug)]
pub enum AdapterPaths {
    XLora {
        adapter_configs: Option<Vec<((String, String), LoraConfig)>>,
        adapter_safetensors: Option<Vec<(String, PathBuf)>>,
        classifier_path: Option<PathBuf>,
        xlora_order: Option<Ordering>,
        xlora_config: Option<XLoraConfig>,
        lora_preload_adapter_info: Option<HashMap<String, (PathBuf, LoraConfig)>>,
    },
    Lora(Vec<LoraAdapterPaths>),
    None,
}

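/// Resolve adapter paths for a model via the Hugging Face Hub API (or its local cache).
/// Supplying LoRA adapter IDs together with an X-LoRA model, or an incomplete X-LoRA
/// specification, is an error: LoRA and X-LoRA are mutually exclusive.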
pub fn get_xlora_paths(
    base_model_id: String,
    xlora_model_id: Option<&String>,
    lora_adapter_ids: Option<&Vec<String>>,
    token_source: &TokenSource,
    revision: String,
    xlora_order: Option<&Ordering>,
) -> Result<AdapterPaths> {
    match (lora_adapter_ids, xlora_model_id, xlora_order) {
        (None, Some(xlora_id), Some(xlora_order)) => {
            let api = {
                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                let mut api = ApiBuilder::from_cache(cache)
                    .with_progress(true)
                    .with_token(get_token(token_source)?);
                if let Some(cache_dir) = crate::hf_hub_cache_dir() {
                    api = api.with_cache_dir(cache_dir);
                }
                api.build().map_err(candle_core::Error::msg)?
            };
            let api = api.repo(Repo::with_revision(
                xlora_id.clone(),
                RepoType::Model,
                revision,
            ));
            let model_id = Path::new(&xlora_id);
            let dir_list = api_dir_list!(api, model_id, true).collect::<Vec<_>>();
            let xlora_classifier = &dir_list
                .clone()
                .into_iter()
                .filter(|x| x.contains("xlora_classifier.safetensors"))
                .collect::<Vec<_>>();
            if xlora_classifier.len() > 1 {
                warn!("Detected multiple X-LoRA classifiers: {xlora_classifier:?}");
                warn!("Selected classifier: `{}`", &xlora_classifier[0]);
            }
            let xlora_classifier = xlora_classifier.first();

            let classifier_path = xlora_classifier
                .map(|xlora_classifier| -> candle_core::Result<_> {
                    Ok(api_get_file!(api, xlora_classifier, model_id))
                })
                .transpose()?;

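            // Locate `xlora_config.json` in the repository and try to deserialize it.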
            let xlora_configs = &dir_list
                .clone()
                .into_iter()
                .filter(|x| x.contains("xlora_config.json"))
                .collect::<Vec<_>>();
            if xlora_configs.len() > 1 {
                warn!("Detected multiple X-LoRA configs: {xlora_configs:?}");
            }

            let mut xlora_config: Option<XLoraConfig> = None;
            let mut last_err: Option<serde_json::Error> = None;
            for (i, config_path) in xlora_configs.iter().enumerate() {
                if xlora_configs.len() != 1 {
                    warn!("Selecting config: `{}`", config_path);
                }
                let config_path = api_get_file!(api, config_path, model_id);
                let conf = fs::read_to_string(config_path)?;
                let deser: Result<XLoraConfig, serde_json::Error> = serde_json::from_str(&conf);
                match deser {
                    Ok(conf) => {
                        xlora_config = Some(conf);
                        break;
                    }
                    Err(e) => {
                        if i != xlora_configs.len() - 1 {
                            warn!("Config is broken with error `{e}`");
                        }
                        last_err = Some(e);
                    }
                }
            }
            let xlora_config = xlora_config.map(Some).unwrap_or_else(|| {
                if let Some(last_err) = last_err {
                    panic!("Unable to deserialize any configs. Last error: {last_err}")
                } else {
                    None
                }
            });

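            // Match repository files against the adapter names listed in the ordering file.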
            let adapter_files = dir_list
                .into_iter()
                .filter_map(|name| {
                    if let Some(ref adapters) = xlora_order.adapters {
                        for adapter_name in adapters {
                            if name.contains(adapter_name) {
                                return Some((name, adapter_name.clone()));
                            }
                        }
                    }
                    None
                })
                .collect::<Vec<_>>();
            if adapter_files.is_empty() && xlora_order.adapters.is_some() {
                anyhow::bail!("Adapter files are empty. Perhaps the adapters in the ordering file do not match the actual adapters?")
            }

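            // Download every matched file, grouping the resulting paths by adapter name.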
            let mut adapters_paths: HashMap<String, Vec<PathBuf>> = HashMap::new();
            for (file, name) in adapter_files {
                if let Some(paths) = adapters_paths.get_mut(&name) {
                    paths.push(api_get_file!(api, &file, model_id));
                } else {
                    adapters_paths.insert(name, vec![api_get_file!(api, &file, model_id)]);
                }
            }

            let mut adapters_configs = Vec::new();
            let mut adapters_safetensors = Vec::new();
            if let Some(ref adapters) = xlora_order.adapters {
                for (i, name) in adapters.iter().enumerate() {
                    let paths = adapters_paths
                        .get(name)
                        .unwrap_or_else(|| panic!("Adapter {name} not found."));
                    for path in paths {
                        if path.extension().unwrap() == "safetensors" {
                            adapters_safetensors.push((name.clone(), path.to_owned()));
                        } else {
                            let conf = fs::read_to_string(path)?;
                            let lora_config: LoraConfig = serde_json::from_str(&conf)?;
                            adapters_configs
                                .push((((i + 1).to_string(), name.clone()), lora_config));
                        }
                    }
                }
            }

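            // Check that the ordering file, the adapter model config, and the base model ID
            // all agree on the same base model.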
            if xlora_order.base_model_id
                != *xlora_config
                    .as_ref()
                    .map(|cfg| &cfg.base_model_id)
                    .unwrap_or(&base_model_id)
                || xlora_config
                    .as_ref()
                    .map(|cfg| &cfg.base_model_id)
                    .unwrap_or(&base_model_id)
                    != &base_model_id
            {
                anyhow::bail!(
                    "Adapter ordering file, adapter model config, and base model ID do not match: {}, {}, and {} respectively.",
                    xlora_order.base_model_id,
                    xlora_config.map(|cfg| cfg.base_model_id).unwrap_or(base_model_id.clone()),
                    base_model_id
                );
            }

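            // Resolve any adapters the ordering file asks to preload, downloading their
            // configs and safetensors up front.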
            let lora_preload_adapter_info =
                if let Some(preload_adapters) = &xlora_order.preload_adapters {
                    let mut output = HashMap::new();
                    for adapter in preload_adapters {
                        let adapter_files = api_dir_list!(api, &adapter.adapter_model_id, true)
                            .filter_map(|f| {
                                if f.contains(&adapter.name) {
                                    Some((f, adapter.name.clone()))
                                } else {
                                    None
                                }
                            })
                            .collect::<Vec<_>>();
                        if adapter_files.is_empty() {
                            anyhow::bail!("Adapter files are empty. Perhaps the adapters in the ordering file do not match the actual adapters?")
                        }
                        let mut adapters_paths: HashMap<String, Vec<PathBuf>> = HashMap::new();
                        for (file, name) in adapter_files {
                            if let Some(paths) = adapters_paths.get_mut(&name) {
                                paths.push(api_get_file!(api, &file, model_id));
                            } else {
                                adapters_paths
                                    .insert(name, vec![api_get_file!(api, &file, model_id)]);
                            }
                        }

                        let mut config = None;
                        let mut safetensor = None;

                        let paths = adapters_paths
                            .get(&adapter.name)
                            .unwrap_or_else(|| panic!("Adapter {} not found.", adapter.name));
                        for path in paths {
                            if path.extension().unwrap() == "safetensors" {
                                safetensor = Some(path.to_owned());
                            } else {
                                let conf = fs::read_to_string(path)?;
                                let lora_config: LoraConfig = serde_json::from_str(&conf)?;
                                config = Some(lora_config);
                            }
                        }

                        let (config, safetensor) = (config.unwrap(), safetensor.unwrap());
                        output.insert(adapter.name.clone(), (safetensor, config));
                    }
                    Some(output)
                } else {
                    None
                };

            Ok(AdapterPaths::XLora {
                adapter_configs: Some(adapters_configs),
                adapter_safetensors: Some(adapters_safetensors),
                classifier_path,
                xlora_order: Some(xlora_order.clone()),
                xlora_config,
                lora_preload_adapter_info,
            })
        }
        (Some(adapter_ids), None, None) => {
            let mut lora_adapter_paths = Vec::new();
            for adapter_id in adapter_ids {
                info!("Loading adapter at `{adapter_id}`");

                let api = {
                    let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                    let mut api = ApiBuilder::from_cache(cache)
                        .with_progress(true)
                        .with_token(get_token(token_source)?);
                    if let Some(cache_dir) = crate::hf_hub_cache_dir() {
                        api = api.with_cache_dir(cache_dir);
                    }
                    api.build().map_err(candle_core::Error::msg)?
                };
                let api = api.repo(Repo::with_revision(
                    adapter_id.clone(),
                    RepoType::Model,
                    revision.clone(),
                ));

                let config_path = api.get("adapter_config.json")?;
                let adapter_path = api.get("adapter_model.safetensors")?;
                let lora_config: mistralrs_quant::LoraConfig =
                    serde_json::from_str(&fs::read_to_string(config_path)?)?;

                lora_adapter_paths.push(LoraAdapterPaths {
                    lora_config,
                    adapter_path,
                });
            }

            Ok(AdapterPaths::Lora(lora_adapter_paths))
        }
        (None, None, None) => Ok(AdapterPaths::None),
        _ => anyhow::bail!(
            "Incorrect configuration for an adapter model. LoRA and X-LoRA are mutually exclusive."
        ),
    }
}

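/// Resolve the weight files for a model. If explicit quantized filenames are given, they
/// are fetched from `quantized_model_id`; otherwise the repository listing is scanned for
/// safetensors, PyTorch pickle checkpoints, or (when loading UQFF) the UQFF residual file.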
pub fn get_model_paths(
    revision: String,
    token_source: &TokenSource,
    quantized_model_id: Option<&String>,
    quantized_filename: Option<&Vec<String>>,
    api: &ApiRepo,
    model_id: &Path,
    loading_from_uqff: bool,
) -> Result<Vec<PathBuf>> {
    match quantized_filename {
        Some(names) => {
            let id = quantized_model_id.unwrap();
            let mut files = Vec::new();

            for name in names {
                let qapi = {
                    let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
                    let mut api = ApiBuilder::from_cache(cache)
                        .with_progress(true)
                        .with_token(get_token(token_source)?);
                    if let Some(cache_dir) = crate::hf_hub_cache_dir() {
                        api = api.with_cache_dir(cache_dir);
                    }
                    api.build().map_err(candle_core::Error::msg)?
                };
                let qapi = qapi.repo(Repo::with_revision(
                    id.to_string(),
                    RepoType::Model,
                    revision.clone(),
                ));
                let model_id = Path::new(&id);
                files.push(api_get_file!(qapi, name, model_id));
            }
            Ok(files)
        }
        None => {
            let safetensor_match = Regex::new(SAFETENSOR_MATCH)?;
            let quant_safetensor_match = Regex::new(QUANT_SAFETENSOR_MATCH)?;
            let pickle_match = Regex::new(PICKLE_MATCH)?;

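            // Scan the repository listing for recognized weight files, preferring
            // safetensors over pickle checkpoints, with the UQFF residual as a fallback
            // when loading from UQFF.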
            let mut filenames = vec![];
            let listing = api_dir_list!(api, model_id, true).filter(|x| {
                safetensor_match.is_match(x)
                    || pickle_match.is_match(x)
                    || quant_safetensor_match.is_match(x)
                    || x == UQFF_RESIDUAL_SAFETENSORS
            });
            let safetensors = listing
                .clone()
                .filter(|x| x.ends_with(".safetensors"))
                .collect::<Vec<_>>();
            let pickles = listing
                .clone()
                .filter(|x| x.ends_with(".pth") || x.ends_with(".pt") || x.ends_with(".bin"))
                .collect::<Vec<_>>();
            let uqff_residual = listing
                .clone()
                .filter(|x| x == UQFF_RESIDUAL_SAFETENSORS)
                .collect::<Vec<_>>();
            let files = if !safetensors.is_empty() {
                safetensors
            } else if !pickles.is_empty() {
                pickles
            } else if !uqff_residual.is_empty() && loading_from_uqff {
                uqff_residual
            } else {
                anyhow::bail!("Expected file with extension one of .safetensors, .pth, .pt, .bin.");
            };
            info!(
                "Found model weight filenames {:?}",
                files
                    .iter()
                    .map(|x| x.split('/').next_back().unwrap())
                    .collect::<Vec<_>>()
            );
            for rfilename in files {
                filenames.push(api_get_file!(api, &rfilename, model_id));
            }
            Ok(filenames)
        }
    }
}

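/// Assemble the [`ChatTemplate`] for a model from the available sources: a literal
/// template override, the model's own template file (`.json` or `.jinja`), an explicit
/// chat template or `.jinja` file, the processor config, or a fallback template file.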
#[allow(clippy::borrowed_box)]
pub(crate) fn get_chat_template(
    paths: &Box<dyn ModelPaths>,
    jinja_explicit: Option<&String>,
    chat_template_explicit: Option<&String>,
    chat_template_fallback: Option<&String>,
    chat_template_ovrd: Option<String>,
) -> ChatTemplate {
    let template_content = if let Some(template_filename) = paths.get_template_filename() {
        if !["jinja", "json"].contains(
            &template_filename
                .extension()
                .expect("Template filename must be a file")
                .to_string_lossy()
                .to_string()
                .as_str(),
        ) {
            panic!("Template filename {template_filename:?} must end with `.json` or `.jinja`.");
        }
        Some(fs::read_to_string(template_filename).expect("Loading chat template failed."))
    } else if chat_template_fallback.is_some_and(|f| f.ends_with(".json")) {
        let template_filename = chat_template_fallback
            .expect("A tokenizer config or chat template file path must be specified.");
        Some(fs::read_to_string(template_filename).expect("Loading chat template failed."))
    } else if chat_template_ovrd.is_some() {
        None
    } else {
        panic!("Expected chat template file to end with .json, or you can specify a tokenizer model ID to load the chat template there. If you are running a GGUF model, it probably does not contain a chat template.");
    };
    let mut template: ChatTemplate = match chat_template_ovrd {
        Some(chat_template) => {
            info!("Using literal chat template.");
            let mut template = ChatTemplate::default();
            template.chat_template = Some(ChatTemplateValue(Either::Left(chat_template)));
            template
        }
        None => {
            if let Some(template_filename) = paths.get_template_filename() {
                if template_filename.extension().map(|e| e.to_str()) == Some(Some("jinja")) {
                    info!("Using chat template from .jinja file.");
                    let mut template = ChatTemplate::default();
                    template.chat_template = Some(ChatTemplateValue(Either::Left(
                        template_content.as_ref().unwrap().clone(),
                    )));
                    template
                } else {
                    serde_json::from_str(&template_content.as_ref().unwrap().clone()).unwrap()
                }
            } else {
                serde_json::from_str(&template_content.as_ref().unwrap().clone()).unwrap()
            }
        }
    };
    if template.chat_template.is_none() {
        if let Some(chat_template_explicit) = chat_template_explicit {
            let ct =
                fs::read_to_string(chat_template_explicit).expect("Loading chat template failed.");

            let new_chat_template = if chat_template_explicit.ends_with(".jinja") {
                ct
            } else {
                #[derive(Debug, serde::Deserialize)]
                struct AutomaticTemplate {
                    chat_template: String,
                }
                let deser: AutomaticTemplate = serde_json::from_str(&ct).unwrap();
                deser.chat_template
            };

            template.chat_template = Some(ChatTemplateValue(Either::Left(new_chat_template)));
        }
    }

    if let Some(jinja_explicit) = jinja_explicit {
        if !jinja_explicit.ends_with(".jinja") {
            panic!("jinja_explicit must end with .jinja!");
        }

        let ct = fs::read_to_string(jinja_explicit).expect("Loading chat template failed.");

        template.chat_template = Some(ChatTemplateValue(Either::Left(ct)));
    }

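    // A chat template supplied by the processor config (used by some vision models)
    // overrides whatever was selected above.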
    let processor_conf: Option<crate::vision_models::processor_config::ProcessorConfig> = paths
        .get_processor_config()
        .as_ref()
        .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
    if let Some(processor_conf) = processor_conf {
        if processor_conf.chat_template.is_some() {
            template.chat_template = processor_conf
                .chat_template
                .map(|x| ChatTemplateValue(Either::Left(x)));
        }
    }

    #[derive(Debug, serde::Deserialize)]
    struct SpecifiedTemplate {
        chat_template: String,
        bos_token: Option<String>,
        eos_token: Option<String>,
        unk_token: Option<String>,
    }

    if template.chat_template.is_some() {
        return template;
    };

    match &template.chat_template {
        Some(_) => template,
        None => {
            info!("`tokenizer_config.json` does not contain a chat template, attempting to use specified JINJA chat template.");
            let mut deser: HashMap<String, Value> =
                serde_json::from_str(&template_content.unwrap()).unwrap();

            match chat_template_fallback.cloned() {
                Some(t) => {
                    info!("Loading specified chat template file at `{t}`.");
                    let templ: SpecifiedTemplate =
                        serde_json::from_str(&fs::read_to_string(t.clone()).unwrap()).unwrap();
                    deser.insert(
                        "chat_template".to_string(),
                        Value::String(templ.chat_template),
                    );
                    if let Some(bos_token) = templ.bos_token {
                        deser.insert("bos_token".to_string(), Value::String(bos_token));
                    }
                    if let Some(eos_token) = templ.eos_token {
                        deser.insert("eos_token".to_string(), Value::String(eos_token));
                    }
                    if let Some(unk_token) = templ.unk_token {
                        deser.insert("unk_token".to_string(), Value::String(unk_token));
                    }
                }
                None => {
                    warn!("No specified chat template. No chat template will be used. Only prompts will be accepted, not messages.");
                    deser.insert("chat_template".to_string(), Value::Null);
                }
            }

            let ser = serde_json::to_string_pretty(&deser)
                .expect("Serialization of modified chat template failed.");
            serde_json::from_str(&ser).unwrap()
        }
    }
}

#[cfg(test)]
mod tests {
    #[test]
    fn match_safetensors() -> anyhow::Result<()> {
        use regex_automata::meta::Regex;

        use super::SAFETENSOR_MATCH;
        let safetensor_match = Regex::new(SAFETENSOR_MATCH)?;

        let positive_ids = [
            "model-00001-of-00001.safetensors",
            "model-00002-of-00002.safetensors",
            "model-00003-of-00003.safetensors",
            "model-00004-of-00004.safetensors",
            "model-00005-of-00005.safetensors",
            "model-00006-of-00006.safetensors",
        ];
        let negative_ids = [
            "model-0000a-of-00002.safetensors",
            "consolidated.safetensors",
        ];
        for id in positive_ids {
            assert!(safetensor_match.is_match(id));
        }
        for id in negative_ids {
            assert!(!safetensor_match.is_match(id));
        }
        Ok(())
    }

    #[test]
    fn match_pickle() -> anyhow::Result<()> {
        use regex_automata::meta::Regex;

        use super::PICKLE_MATCH;
        let pickle_match = Regex::new(PICKLE_MATCH)?;

        let positive_ids = [
            "pytorch_model-00001-of-00002.bin",
            "pytorch_model-00002-of-00002.bin",
        ];
        let negative_ids = [
            "pytorch_model-000001-of-00001.bin",
            "pytorch_model-0000a-of-00002.bin",
            "pytorch_model-000-of-00003.bin",
            "pytorch_consolidated.bin",
        ];
        for id in positive_ids {
            assert!(pickle_match.is_match(id));
        }
        for id in negative_ids {
            assert!(!pickle_match.is_match(id));
        }
        Ok(())
    }
}