1use std::{
2 collections::HashMap,
3 fs,
4 path::{Path, PathBuf},
5};
6
7use anyhow::Result;
8use either::Either;
9use hf_hub::{
10 api::sync::{ApiBuilder, ApiRepo},
11 Repo, RepoType,
12};
13use regex_automata::meta::Regex;
14use serde_json::Value;
15use tracing::{debug, info, trace, warn};
16
17use crate::{
18 api_dir_list, api_get_file,
19 lora::LoraConfig,
20 pipeline::{
21 chat_template::{BeginEndUnkPadTok, ChatTemplate, ChatTemplateValue},
22 isq::UQFF_RESIDUAL_SAFETENSORS,
23 },
24 utils::tokens::get_token,
25 xlora_models::XLoraConfig,
26 ModelPaths, Ordering, TokenSource, GLOBAL_HF_CACHE,
27};
28
/// Sharded safetensors weights, e.g. `model-00001-of-00004.safetensors`.
///
/// Like the other patterns below it is unanchored, so it also matches file names
/// that are nested inside a longer repo path.
const SAFETENSOR_MATCH: &str = r"model-\d+-of-\d+\.safetensors\b";
/// Single-file `model.safetensors` weights.
const QUANT_SAFETENSOR_MATCH: &str = r"model\.safetensors\b";
/// Single-file `consolidated.safetensors` weights.
const CONSOLIDATED_SAFETENSOR_MATCH: &str = r"consolidated\.safetensors\b";
/// Sharded PyTorch pickle weights with 5-digit shard numbers, e.g.
/// `pytorch_model-00001-of-00002.bin` (`.pth` and `.pt` are accepted too).
///
/// The extension separator is escaped so that only a literal `.` matches. An unescaped
/// `.` would match any character (e.g. `...-00002Xbin`).
const PICKLE_MATCH: &str = r"pytorch_model-\d{5}-of-\d{5}\.((pth)|(pt)|(bin))\b";
34
/// Resolved files for a single plain LoRA adapter, as produced by the LoRA branch of
/// `get_xlora_paths`.
#[derive(Clone, Debug)]
pub struct LoraAdapterPaths {
    /// Parsed contents of the adapter's `adapter_config.json`.
    pub lora_config: hanzo_quant::LoraConfig,
    /// Local path of the adapter's `adapter_model.safetensors` weights.
    pub adapter_path: PathBuf,
}
40
/// Adapter files resolved by `get_xlora_paths`, one variant per adapter mode.
// The `XLora` variant carries far more data than the others; the size lint is silenced
// here rather than boxing its fields.
#[allow(clippy::large_enum_variant)]
#[derive(Clone, Debug)]
pub enum AdapterPaths {
    /// X-LoRA model: the adapters named by an ordering file plus the X-LoRA classifier.
    XLora {
        /// LoRA configs keyed by `(1-based position in the ordering, adapter name)`.
        adapter_configs: Option<Vec<((String, String), LoraConfig)>>,
        /// `(adapter name, weights path)` for every `.safetensors` file of each adapter.
        adapter_safetensors: Option<Vec<(String, PathBuf)>>,
        /// Local path of `xlora_classifier.safetensors`, if the repo contains one.
        classifier_path: Option<PathBuf>,
        /// The adapter ordering the paths were resolved from.
        xlora_order: Option<Ordering>,
        /// Parsed `xlora_config.json`, if one was found.
        xlora_config: Option<XLoraConfig>,
        /// Preload adapters: adapter name -> (weights path, LoRA config).
        lora_preload_adapter_info: Option<HashMap<String, (PathBuf, LoraConfig)>>,
    },
    /// One entry per plain LoRA adapter.
    Lora(Vec<LoraAdapterPaths>),
    /// No adapter is used.
    None,
}
55
56pub fn get_xlora_paths(
57 base_model_id: String,
58 xlora_model_id: Option<&String>,
59 lora_adapter_ids: Option<&Vec<String>>,
60 token_source: &TokenSource,
61 revision: String,
62 xlora_order: Option<&Ordering>,
63) -> Result<AdapterPaths> {
64 match (lora_adapter_ids, xlora_model_id, xlora_order) {
65 (None, Some(xlora_id), Some(xlora_order)) => {
66 let api = {
67 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
68 let mut api = ApiBuilder::from_cache(cache)
69 .with_progress(true)
70 .with_token(get_token(token_source)?);
71 if let Some(cache_dir) = crate::hf_hub_cache_dir() {
72 api = api.with_cache_dir(cache_dir);
73 }
74 api.build().map_err(hanzo_ml::Error::msg)?
75 };
76 let api = api.repo(Repo::with_revision(
77 xlora_id.clone(),
78 RepoType::Model,
79 revision.clone(),
80 ));
81 let model_id = Path::new(&xlora_id);
82 let dir_list = api_dir_list!(api, model_id, true, &revision).collect::<Vec<_>>();
83 let xlora_classifier = &dir_list
85 .clone()
86 .into_iter()
87 .filter(|x| x.contains("xlora_classifier.safetensors"))
88 .collect::<Vec<_>>();
89 if xlora_classifier.len() > 1 {
90 warn!("Detected multiple X-LoRA classifiers: {xlora_classifier:?}");
91 warn!("Selected classifier: `{}`", &xlora_classifier[0]);
92 }
93 let xlora_classifier = xlora_classifier.first();
94
95 let classifier_path = xlora_classifier
96 .map(|xlora_classifier| -> hanzo_ml::Result<_> {
97 Ok(api_get_file!(api, xlora_classifier, model_id, &revision))
98 })
99 .transpose()?;
100
101 let xlora_configs = &dir_list
104 .clone()
105 .into_iter()
106 .filter(|x| x.contains("xlora_config.json"))
107 .collect::<Vec<_>>();
108 if xlora_configs.len() > 1 {
109 warn!("Detected multiple X-LoRA configs: {xlora_configs:?}");
110 }
111
112 let mut xlora_config: Option<XLoraConfig> = None;
113 let mut last_err: Option<serde_json::Error> = None;
114 for (i, config_path) in xlora_configs.iter().enumerate() {
115 if xlora_configs.len() != 1 {
116 warn!("Selecting config: `{}`", config_path);
117 }
118 let config_path = api_get_file!(api, config_path, model_id, &revision);
119 let conf = fs::read_to_string(config_path)?;
120 let deser: Result<XLoraConfig, serde_json::Error> = serde_json::from_str(&conf);
121 match deser {
122 Ok(conf) => {
123 xlora_config = Some(conf);
124 break;
125 }
126 Err(e) => {
127 if i != xlora_configs.len() - 1 {
128 warn!("Config is broken with error `{e}`");
129 }
130 last_err = Some(e);
131 }
132 }
133 }
134 let xlora_config = xlora_config.map(Some).unwrap_or_else(|| {
135 if let Some(last_err) = last_err {
136 panic!("Unable to derserialize any configs. Last error: {last_err}")
137 } else {
138 None
139 }
140 });
141
142 let adapter_files = dir_list
144 .into_iter()
145 .filter_map(|name| {
146 if let Some(ref adapters) = xlora_order.adapters {
147 for adapter_name in adapters {
148 if name.contains(adapter_name) {
149 return Some((name, adapter_name.clone()));
150 }
151 }
152 }
153 None
154 })
155 .collect::<Vec<_>>();
156 if adapter_files.is_empty() && xlora_order.adapters.is_some() {
157 anyhow::bail!("Adapter files are empty. Perhaps the ordering file adapters does not match the actual adapters?")
158 }
159
160 let mut adapters_paths: HashMap<String, Vec<PathBuf>> = HashMap::new();
162 for (file, name) in adapter_files {
163 if let Some(paths) = adapters_paths.get_mut(&name) {
164 paths.push(api_get_file!(api, &file, model_id, &revision));
165 } else {
166 adapters_paths
167 .insert(name, vec![api_get_file!(api, &file, model_id, &revision)]);
168 }
169 }
170
171 let mut adapters_configs = Vec::new();
173 let mut adapters_safetensors = Vec::new();
174 if let Some(ref adapters) = xlora_order.adapters {
175 for (i, name) in adapters.iter().enumerate() {
176 let paths = adapters_paths
177 .get(name)
178 .unwrap_or_else(|| panic!("Adapter {name} not found."));
179 for path in paths {
180 if path.extension().unwrap() == "safetensors" {
181 adapters_safetensors.push((name.clone(), path.to_owned()));
182 } else {
183 let conf = fs::read_to_string(path)?;
184 let lora_config: LoraConfig = serde_json::from_str(&conf)?;
185 adapters_configs
186 .push((((i + 1).to_string(), name.clone()), lora_config));
187 }
188 }
189 }
190 }
191
192 if xlora_order.base_model_id
194 != *xlora_config
195 .as_ref()
196 .map(|cfg| &cfg.base_model_id)
197 .unwrap_or(&base_model_id)
198 || xlora_config
199 .as_ref()
200 .map(|cfg| &cfg.base_model_id)
201 .unwrap_or(&base_model_id)
202 != &base_model_id
203 {
204 anyhow::bail!(
205 "Adapter ordering file, adapter model config, and base model ID do not match: {}, {}, and {} respectively.",
206 xlora_order.base_model_id,
207 xlora_config.map(|cfg| cfg.base_model_id).unwrap_or(base_model_id.clone()),
208 base_model_id
209 );
210 }
211
212 let lora_preload_adapter_info =
213 if let Some(preload_adapters) = &xlora_order.preload_adapters {
215 let mut output = HashMap::new();
216 for adapter in preload_adapters {
217 let adapter_files = api_dir_list!(api, &adapter.adapter_model_id, true, &revision)
219 .filter_map(|f| {
220 if f.contains(&adapter.name) {
221 Some((f, adapter.name.clone()))
222 } else {
223 None
224 }
225 })
226 .collect::<Vec<_>>();
227 if adapter_files.is_empty() {
228 anyhow::bail!("Adapter files are empty. Perhaps the ordering file adapters does not match the actual adapters?")
229 }
230 let mut adapters_paths: HashMap<String, Vec<PathBuf>> = HashMap::new();
232 for (file, name) in adapter_files {
233 if let Some(paths) = adapters_paths.get_mut(&name) {
234 paths.push(api_get_file!(api, &file, model_id, &revision));
235 } else {
236 adapters_paths
237 .insert(name, vec![api_get_file!(api, &file, model_id, &revision)]);
238 }
239 }
240
241 let mut config = None;
242 let mut safetensor = None;
243
244 let paths = adapters_paths
246 .get(&adapter.name)
247 .unwrap_or_else(|| panic!("Adapter {} not found.", adapter.name));
248 for path in paths {
249 if path.extension().unwrap() == "safetensors" {
250 safetensor = Some(path.to_owned());
251 } else {
252 let conf = fs::read_to_string(path)?;
253 let lora_config: LoraConfig = serde_json::from_str(&conf)?;
254 config = Some(lora_config);
255 }
256 }
257
258 let (config, safetensor) = (config.unwrap(), safetensor.unwrap());
259 output.insert(adapter.name.clone(), (safetensor, config));
260 }
261 Some(output)
262 } else {
263 None
264 };
265
266 Ok(AdapterPaths::XLora {
267 adapter_configs: Some(adapters_configs),
268 adapter_safetensors: Some(adapters_safetensors),
269 classifier_path,
270 xlora_order: Some(xlora_order.clone()),
271 xlora_config,
272 lora_preload_adapter_info,
273 })
274 }
275 (Some(adapter_ids), None, None) => {
276 let mut lora_adapter_paths = Vec::new();
277 for adapter_id in adapter_ids {
278 info!("Loading adapter at `{adapter_id}`");
279
280 let api = {
281 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
282 let mut api = ApiBuilder::from_cache(cache)
283 .with_progress(true)
284 .with_token(get_token(token_source)?);
285 if let Some(cache_dir) = crate::hf_hub_cache_dir() {
286 api = api.with_cache_dir(cache_dir);
287 }
288 api.build().map_err(hanzo_ml::Error::msg)?
289 };
290 let api = api.repo(Repo::with_revision(
291 adapter_id.clone(),
292 RepoType::Model,
293 revision.clone(),
294 ));
295
296 let adapter_path_buf = std::path::Path::new(adapter_id);
297 let config_path = crate::pipeline::hf::get_file(
298 &api,
299 adapter_path_buf,
300 "adapter_config.json",
301 &revision,
302 )?;
303 let adapter_path = crate::pipeline::hf::get_file(
304 &api,
305 adapter_path_buf,
306 "adapter_model.safetensors",
307 &revision,
308 )?;
309 let lora_config: hanzo_quant::LoraConfig =
310 serde_json::from_str(&fs::read_to_string(config_path)?)?;
311
312 lora_adapter_paths.push(LoraAdapterPaths {
313 lora_config,
314 adapter_path,
315 });
316 }
317
318 Ok(AdapterPaths::Lora(lora_adapter_paths))
319 }
320 (None, None, None) => Ok(AdapterPaths::None),
321 _ => anyhow::bail!(
322 "Incorrect configuration for an adapter model. Lora and XLora are mutually exclusive."
323 ),
324 }
325}
326
327pub fn get_model_paths(
328 revision: String,
329 token_source: &TokenSource,
330 quantized_model_id: Option<&String>,
331 quantized_filename: Option<&Vec<String>>,
332 api: &ApiRepo,
333 model_id: &Path,
334 loading_from_uqff: bool,
335) -> Result<Vec<PathBuf>> {
336 match quantized_filename {
337 Some(names) => {
338 let id = quantized_model_id.unwrap();
339 let mut files = Vec::new();
340
341 for name in names {
342 let qapi = {
343 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
344 let mut api = ApiBuilder::from_cache(cache)
345 .with_progress(true)
346 .with_token(get_token(token_source)?);
347 if let Some(cache_dir) = crate::hf_hub_cache_dir() {
348 api = api.with_cache_dir(cache_dir);
349 }
350 api.build().map_err(hanzo_ml::Error::msg)?
351 };
352 let qapi = qapi.repo(Repo::with_revision(
353 id.to_string(),
354 RepoType::Model,
355 revision.clone(),
356 ));
357 let model_id = Path::new(&id);
358 files.push(api_get_file!(qapi, name, model_id, &revision));
359 }
360 Ok(files)
361 }
362 None => {
363 let safetensor_match = Regex::new(SAFETENSOR_MATCH)?;
365 let quant_safetensor_match = Regex::new(QUANT_SAFETENSOR_MATCH)?;
366 let consolidated_safetensor_match = Regex::new(CONSOLIDATED_SAFETENSOR_MATCH)?;
367 let pickle_match = Regex::new(PICKLE_MATCH)?;
368
369 let mut filenames = vec![];
370 let listing = api_dir_list!(api, model_id, true, &revision).filter(|x| {
371 safetensor_match.is_match(x)
372 || pickle_match.is_match(x)
373 || quant_safetensor_match.is_match(x)
374 || consolidated_safetensor_match.is_match(x)
375 || x == UQFF_RESIDUAL_SAFETENSORS
376 });
377 let safetensors = listing
378 .clone()
379 .filter(|x| x.ends_with(".safetensors"))
380 .collect::<Vec<_>>();
381 let pickles = listing
382 .clone()
383 .filter(|x| x.ends_with(".pth") || x.ends_with(".pt") || x.ends_with(".bin"))
384 .collect::<Vec<_>>();
385 let uqff_residual = listing
386 .clone()
387 .filter(|x| x == UQFF_RESIDUAL_SAFETENSORS)
388 .collect::<Vec<_>>();
389 let files = if !safetensors.is_empty() {
390 safetensors
392 } else if !pickles.is_empty() {
393 pickles
395 } else if !uqff_residual.is_empty() && loading_from_uqff {
396 uqff_residual
397 } else {
398 anyhow::bail!("Expected file with extension one of .safetensors, .pth, .pt, .bin.");
399 };
400 trace!(
401 "Found model weight filenames {:?}",
402 files
403 .iter()
404 .map(|x| x.split('/').next_back().unwrap())
405 .collect::<Vec<_>>()
406 );
407 for rfilename in files {
408 filenames.push(api_get_file!(api, &rfilename, model_id, &revision));
409 }
410 Ok(filenames)
411 }
412 }
413}
414
415#[allow(clippy::borrowed_box)]
428pub(crate) fn get_chat_template(
429 paths: &Box<dyn ModelPaths>,
430 jinja_explicit: Option<&String>,
431 chat_template_explicit: Option<&String>,
432 chat_template_fallback: Option<&String>,
433 chat_template_ovrd: Option<String>,
434) -> ChatTemplate {
435 let template_content = if let Some(template_filename) = paths.get_template_filename() {
437 if !["jinja", "json"].contains(
438 &template_filename
439 .extension()
440 .expect("Template filename must be a file")
441 .to_string_lossy()
442 .to_string()
443 .as_str(),
444 ) {
445 panic!("Template filename {template_filename:?} must end with `.json` or `.jinja`.");
446 }
447 Some(fs::read_to_string(template_filename).expect("Loading chat template failed."))
448 } else if chat_template_fallback.is_some_and(|f| f.ends_with(".json")) {
449 let template_filename = chat_template_fallback
451 .expect("A tokenizer config or chat template file path must be specified.");
452 Some(fs::read_to_string(template_filename).expect("Loading chat template failed."))
453 } else if chat_template_ovrd.is_some() {
454 None
455 } else {
456 debug!("No chat template file found. Chat template may be set via `chat_template.json` or processor config.");
457 None
458 };
459 let mut template: ChatTemplate = match chat_template_ovrd {
460 Some(chat_template) => {
461 debug!("Using literal chat template.");
463 let mut template = ChatTemplate::default();
464 template.chat_template = Some(ChatTemplateValue(Either::Left(chat_template)));
465 template
466 }
467 None => {
468 if let Some(ref content) = template_content {
469 if let Some(template_filename) = paths.get_template_filename() {
471 if template_filename.extension().map(|e| e.to_str()) == Some(Some("jinja")) {
472 debug!("Using chat template from .jinja file.");
473 let mut template = template_filename
477 .parent()
478 .map(|dir| dir.join("tokenizer_config.json"))
479 .filter(|p| p.exists())
480 .and_then(|p| fs::read_to_string(p).ok())
481 .and_then(|s| serde_json::from_str::<ChatTemplate>(&s).ok())
482 .unwrap_or_else(|| {
483 let mut ct = ChatTemplate::default();
487 if let Some(tok_path) = paths
488 .get_tokenizer_filename()
489 .parent()
490 .map(|d| d.join("tokenizer.json"))
491 .filter(|p| p.exists())
492 .or_else(|| {
493 template_filename
494 .parent()
495 .map(|d| d.join("tokenizer.json"))
496 .filter(|p| p.exists())
497 })
498 {
499 if let Some(tok_json) =
500 fs::read_to_string(&tok_path).ok().and_then(|s| {
501 serde_json::from_str::<serde_json::Value>(&s).ok()
502 })
503 {
504 let added = tok_json
505 .get("added_tokens")
506 .and_then(serde_json::Value::as_array);
507 for token in added.into_iter().flatten() {
508 let content = token
509 .get("content")
510 .and_then(serde_json::Value::as_str)
511 .unwrap_or("");
512 let special = token
513 .get("special")
514 .and_then(serde_json::Value::as_bool)
515 .unwrap_or(false);
516 if special {
517 if content == "<bos>" {
518 ct.bos_token = Some(BeginEndUnkPadTok(
519 Either::Left(content.to_string()),
520 ));
521 } else if content == "<eos>" {
522 ct.eos_token = Some(BeginEndUnkPadTok(
523 Either::Left(content.to_string()),
524 ));
525 } else if content == "<unk>" {
526 ct.unk_token = Some(BeginEndUnkPadTok(
527 Either::Left(content.to_string()),
528 ));
529 }
530 }
531 }
532 }
533 }
534 ct
535 });
536 template.chat_template =
537 Some(ChatTemplateValue(Either::Left(content.clone())));
538 template
539 } else {
540 serde_json::from_str(content).unwrap()
541 }
542 } else {
543 serde_json::from_str(content).unwrap()
544 }
545 } else {
546 ChatTemplate::default()
549 }
550 }
551 };
552 if template.chat_template.is_none() {
554 if let Some(chat_template_explicit) = chat_template_explicit {
555 let ct =
556 fs::read_to_string(chat_template_explicit).expect("Loading chat template failed.");
557
558 let new_chat_template = if chat_template_explicit.ends_with(".jinja") {
559 ct
560 } else {
561 #[derive(Debug, serde::Deserialize)]
562 struct AutomaticTemplate {
563 chat_template: String,
564 }
565 let deser: AutomaticTemplate = serde_json::from_str(&ct).unwrap();
566 deser.chat_template
567 };
568
569 template.chat_template = Some(ChatTemplateValue(Either::Left(new_chat_template)));
570 }
571 }
572
573 if let Some(jinja_explicit) = jinja_explicit {
575 if !jinja_explicit.ends_with(".jinja") {
576 panic!("jinja_explicit must end with .jinja!");
577 }
578
579 let ct = fs::read_to_string(jinja_explicit).expect("Loading chat template failed.");
580
581 template.chat_template = Some(ChatTemplateValue(Either::Left(ct)));
582 }
583
584 let processor_conf: Option<crate::vision_models::processor_config::ProcessorConfig> = paths
585 .get_processor_config()
586 .as_ref()
587 .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
588 if let Some(processor_conf) = processor_conf {
589 if processor_conf.chat_template.is_some() {
590 template.chat_template = processor_conf
591 .chat_template
592 .map(|x| ChatTemplateValue(Either::Left(x)));
593 }
594 }
595
596 #[derive(Debug, serde::Deserialize)]
597 struct SpecifiedTemplate {
598 chat_template: String,
599 bos_token: Option<String>,
600 eos_token: Option<String>,
601 unk_token: Option<String>,
602 }
603
604 if template.chat_template.is_some() {
605 return template;
606 };
607
608 match &template.chat_template {
609 Some(_) => template,
610 None => {
611 if let Some(template_content) = template_content {
612 info!("`tokenizer_config.json` does not contain a chat template, attempting to use specified JINJA chat template.");
613 let mut deser: HashMap<String, Value> =
614 serde_json::from_str(&template_content).unwrap();
615
616 match chat_template_fallback.cloned() {
617 Some(t) => {
618 info!("Loading specified loading chat template file at `{t}`.");
619 let templ: SpecifiedTemplate =
620 serde_json::from_str(&fs::read_to_string(t.clone()).unwrap()).unwrap();
621 deser.insert(
622 "chat_template".to_string(),
623 Value::String(templ.chat_template),
624 );
625 if let Some(bos_token) = templ.bos_token {
626 deser.insert("bos_token".to_string(), Value::String(bos_token));
627 }
628 if let Some(eos_token) = templ.eos_token {
629 deser.insert("eos_token".to_string(), Value::String(eos_token));
630 }
631 if let Some(unk_token) = templ.unk_token {
632 deser.insert("unk_token".to_string(), Value::String(unk_token));
633 }
634 }
635 None => {
636 warn!("No specified chat template. No chat template will be used. Only prompts will be accepted, not messages.");
637 deser.insert("chat_template".to_string(), Value::Null);
638 }
639 }
640
641 let ser = serde_json::to_string_pretty(&deser)
642 .expect("Serialization of modified chat template failed.");
643 serde_json::from_str(&ser).unwrap()
644 } else {
645 warn!("No chat template source found. No chat template will be used. Only prompts will be accepted, not messages.");
646 template
647 }
648 }
649 }
650}
651
/// Unit tests for the weight-file name patterns used by `get_model_paths`.
#[cfg(test)]
mod tests {
    use regex_automata::meta::Regex;

    use super::{
        CONSOLIDATED_SAFETENSOR_MATCH, PICKLE_MATCH, QUANT_SAFETENSOR_MATCH, SAFETENSOR_MATCH,
    };

    #[test]
    fn match_safetensors() -> anyhow::Result<()> {
        let safetensor_match = Regex::new(SAFETENSOR_MATCH)?;

        let positive_ids = [
            "model-00001-of-00001.safetensors",
            "model-00002-of-00002.safetensors",
            "model-00003-of-00003.safetensors",
            "model-00004-of-00004.safetensors",
            "model-00005-of-00005.safetensors",
            "model-00006-of-00006.safetensors",
        ];
        let negative_ids = [
            "model-0000a-of-00002.safetensors",
            "consolidated.safetensors",
        ];
        for id in positive_ids {
            assert!(safetensor_match.is_match(id), "expected match: {id}");
        }
        for id in negative_ids {
            assert!(!safetensor_match.is_match(id), "unexpected match: {id}");
        }
        Ok(())
    }

    #[test]
    fn match_pickle() -> anyhow::Result<()> {
        let pickle_match = Regex::new(PICKLE_MATCH)?;

        let positive_ids = [
            "pytorch_model-00001-of-00002.bin",
            "pytorch_model-00002-of-00002.bin",
            "pytorch_model-00001-of-00002.pt",
            "pytorch_model-00001-of-00002.pth",
        ];
        let negative_ids = [
            "pytorch_model-000001-of-00001.bin",
            "pytorch_model-0000a-of-00002.bin",
            "pytorch_model-000-of-00003.bin",
            "pytorch_consolidated.bin",
        ];
        for id in positive_ids {
            assert!(pickle_match.is_match(id), "expected match: {id}");
        }
        for id in negative_ids {
            assert!(!pickle_match.is_match(id), "unexpected match: {id}");
        }
        Ok(())
    }

    #[test]
    fn match_single_file_safetensors() -> anyhow::Result<()> {
        let quant_match = Regex::new(QUANT_SAFETENSOR_MATCH)?;
        let consolidated_match = Regex::new(CONSOLIDATED_SAFETENSOR_MATCH)?;

        assert!(quant_match.is_match("model.safetensors"));
        assert!(!quant_match.is_match("model-00001-of-00002.safetensors"));
        assert!(consolidated_match.is_match("consolidated.safetensors"));
        assert!(!consolidated_match.is_match("model.safetensors"));
        Ok(())
    }
}