Skip to main content

hanzo_engine/pipeline/
paths.rs

1use std::{
2    collections::HashMap,
3    fs,
4    path::{Path, PathBuf},
5};
6
7use anyhow::Result;
8use either::Either;
9use hf_hub::{
10    api::sync::{ApiBuilder, ApiRepo},
11    Repo, RepoType,
12};
13use regex_automata::meta::Regex;
14use serde_json::Value;
15use tracing::{debug, info, trace, warn};
16
17use crate::{
18    api_dir_list, api_get_file,
19    lora::LoraConfig,
20    pipeline::{
21        chat_template::{BeginEndUnkPadTok, ChatTemplate, ChatTemplateValue},
22        isq::UQFF_RESIDUAL_SAFETENSORS,
23    },
24    utils::tokens::get_token,
25    xlora_models::XLoraConfig,
26    ModelPaths, Ordering, TokenSource, GLOBAL_HF_CACHE,
27};
28
// Patterns used to recognise model weight files in a repository listing.
// None of them is anchored, so a file name matches when the pattern occurs anywhere in it.

/// Sharded safetensors weights, e.g. `model-00001-of-00002.safetensors`.
const SAFETENSOR_MATCH: &str = r"model-\d+-of-\d+\.safetensors\b";
/// Single-file safetensors weights, i.e. `model.safetensors`.
const QUANT_SAFETENSOR_MATCH: &str = r"model\.safetensors\b";
/// Consolidated safetensors weights, i.e. `consolidated.safetensors`.
const CONSOLIDATED_SAFETENSOR_MATCH: &str = r"consolidated\.safetensors\b";
/// Sharded pickle weights with five-digit shard indices and a `.pth`, `.pt` or `.bin`
/// extension, e.g. `pytorch_model-00001-of-00002.bin`.
///
/// The dot before the extension is escaped so that only a literal `.` is accepted there.
const PICKLE_MATCH: &str = r"pytorch_model-\d{5}-of-\d{5}\.((pth)|(pt)|(bin))\b";
34
35#[derive(Clone, Debug)]
36pub struct LoraAdapterPaths {
37    pub lora_config: hanzo_quant::LoraConfig,
38    pub adapter_path: PathBuf,
39}
40
41#[allow(clippy::large_enum_variant)]
42#[derive(Clone, Debug)]
43pub enum AdapterPaths {
44    XLora {
45        adapter_configs: Option<Vec<((String, String), LoraConfig)>>,
46        adapter_safetensors: Option<Vec<(String, PathBuf)>>,
47        classifier_path: Option<PathBuf>,
48        xlora_order: Option<Ordering>,
49        xlora_config: Option<XLoraConfig>,
50        lora_preload_adapter_info: Option<HashMap<String, (PathBuf, LoraConfig)>>,
51    },
52    Lora(Vec<LoraAdapterPaths>),
53    None,
54}
55
56pub fn get_xlora_paths(
57    base_model_id: String,
58    xlora_model_id: Option<&String>,
59    lora_adapter_ids: Option<&Vec<String>>,
60    token_source: &TokenSource,
61    revision: String,
62    xlora_order: Option<&Ordering>,
63) -> Result<AdapterPaths> {
64    match (lora_adapter_ids, xlora_model_id, xlora_order) {
65        (None, Some(xlora_id), Some(xlora_order)) => {
66            let api = {
67                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
68                let mut api = ApiBuilder::from_cache(cache)
69                    .with_progress(true)
70                    .with_token(get_token(token_source)?);
71                if let Some(cache_dir) = crate::hf_hub_cache_dir() {
72                    api = api.with_cache_dir(cache_dir);
73                }
74                api.build().map_err(hanzo_ml::Error::msg)?
75            };
76            let api = api.repo(Repo::with_revision(
77                xlora_id.clone(),
78                RepoType::Model,
79                revision.clone(),
80            ));
81            let model_id = Path::new(&xlora_id);
82            let dir_list = api_dir_list!(api, model_id, true, &revision).collect::<Vec<_>>();
83            // Get the path for the xlora classifier
84            let xlora_classifier = &dir_list
85                .clone()
86                .into_iter()
87                .filter(|x| x.contains("xlora_classifier.safetensors"))
88                .collect::<Vec<_>>();
89            if xlora_classifier.len() > 1 {
90                warn!("Detected multiple X-LoRA classifiers: {xlora_classifier:?}");
91                warn!("Selected classifier: `{}`", &xlora_classifier[0]);
92            }
93            let xlora_classifier = xlora_classifier.first();
94
95            let classifier_path = xlora_classifier
96                .map(|xlora_classifier| -> hanzo_ml::Result<_> {
97                    Ok(api_get_file!(api, xlora_classifier, model_id, &revision))
98                })
99                .transpose()?;
100
101            // Get the path for the xlora config by checking all for valid versions.
102            // NOTE(hanzoai): Remove this functionality because all configs should be deserializable
103            let xlora_configs = &dir_list
104                .clone()
105                .into_iter()
106                .filter(|x| x.contains("xlora_config.json"))
107                .collect::<Vec<_>>();
108            if xlora_configs.len() > 1 {
109                warn!("Detected multiple X-LoRA configs: {xlora_configs:?}");
110            }
111
112            let mut xlora_config: Option<XLoraConfig> = None;
113            let mut last_err: Option<serde_json::Error> = None;
114            for (i, config_path) in xlora_configs.iter().enumerate() {
115                if xlora_configs.len() != 1 {
116                    warn!("Selecting config: `{}`", config_path);
117                }
118                let config_path = api_get_file!(api, config_path, model_id, &revision);
119                let conf = fs::read_to_string(config_path)?;
120                let deser: Result<XLoraConfig, serde_json::Error> = serde_json::from_str(&conf);
121                match deser {
122                    Ok(conf) => {
123                        xlora_config = Some(conf);
124                        break;
125                    }
126                    Err(e) => {
127                        if i != xlora_configs.len() - 1 {
128                            warn!("Config is broken with error `{e}`");
129                        }
130                        last_err = Some(e);
131                    }
132                }
133            }
134            let xlora_config = xlora_config.map(Some).unwrap_or_else(|| {
135                if let Some(last_err) = last_err {
136                    panic!("Unable to derserialize any configs. Last error: {last_err}")
137                } else {
138                    None
139                }
140            });
141
142            // If there are adapters in the ordering file, get their names and remote paths
143            let adapter_files = dir_list
144                .into_iter()
145                .filter_map(|name| {
146                    if let Some(ref adapters) = xlora_order.adapters {
147                        for adapter_name in adapters {
148                            if name.contains(adapter_name) {
149                                return Some((name, adapter_name.clone()));
150                            }
151                        }
152                    }
153                    None
154                })
155                .collect::<Vec<_>>();
156            if adapter_files.is_empty() && xlora_order.adapters.is_some() {
157                anyhow::bail!("Adapter files are empty. Perhaps the ordering file adapters does not match the actual adapters?")
158            }
159
160            // Get the local paths for each adapter
161            let mut adapters_paths: HashMap<String, Vec<PathBuf>> = HashMap::new();
162            for (file, name) in adapter_files {
163                if let Some(paths) = adapters_paths.get_mut(&name) {
164                    paths.push(api_get_file!(api, &file, model_id, &revision));
165                } else {
166                    adapters_paths
167                        .insert(name, vec![api_get_file!(api, &file, model_id, &revision)]);
168                }
169            }
170
171            // Sort local paths for the adapter configs and safetensors files
172            let mut adapters_configs = Vec::new();
173            let mut adapters_safetensors = Vec::new();
174            if let Some(ref adapters) = xlora_order.adapters {
175                for (i, name) in adapters.iter().enumerate() {
176                    let paths = adapters_paths
177                        .get(name)
178                        .unwrap_or_else(|| panic!("Adapter {name} not found."));
179                    for path in paths {
180                        if path.extension().unwrap() == "safetensors" {
181                            adapters_safetensors.push((name.clone(), path.to_owned()));
182                        } else {
183                            let conf = fs::read_to_string(path)?;
184                            let lora_config: LoraConfig = serde_json::from_str(&conf)?;
185                            adapters_configs
186                                .push((((i + 1).to_string(), name.clone()), lora_config));
187                        }
188                    }
189                }
190            }
191
192            // Make sure they all match
193            if xlora_order.base_model_id
194                != *xlora_config
195                    .as_ref()
196                    .map(|cfg| &cfg.base_model_id)
197                    .unwrap_or(&base_model_id)
198                || xlora_config
199                    .as_ref()
200                    .map(|cfg| &cfg.base_model_id)
201                    .unwrap_or(&base_model_id)
202                    != &base_model_id
203            {
204                anyhow::bail!(
205                    "Adapter ordering file, adapter model config, and base model ID do not match: {}, {}, and {} respectively.",
206                    xlora_order.base_model_id,
207                    xlora_config.map(|cfg| cfg.base_model_id).unwrap_or(base_model_id.clone()),
208                    base_model_id
209                );
210            }
211
212            let lora_preload_adapter_info =
213                // If preload adapters are specified, get their metadata like above
214                if let Some(preload_adapters) = &xlora_order.preload_adapters {
215                    let mut output = HashMap::new();
216                    for adapter in preload_adapters {
217                        // Get the names and remote paths of the files associated with this adapter
218                        let adapter_files = api_dir_list!(api, &adapter.adapter_model_id, true, &revision)
219                            .filter_map(|f| {
220                                if f.contains(&adapter.name) {
221                                    Some((f, adapter.name.clone()))
222                                } else {
223                                    None
224                                }
225                            })
226                            .collect::<Vec<_>>();
227                        if adapter_files.is_empty() {
228                            anyhow::bail!("Adapter files are empty. Perhaps the ordering file adapters does not match the actual adapters?")
229                        }
230                        // Get local paths for this adapter
231                        let mut adapters_paths: HashMap<String, Vec<PathBuf>> = HashMap::new();
232                        for (file, name) in adapter_files {
233                            if let Some(paths) = adapters_paths.get_mut(&name) {
234                                paths.push(api_get_file!(api, &file, model_id, &revision));
235                            } else {
236                                adapters_paths
237                                    .insert(name, vec![api_get_file!(api, &file, model_id, &revision)]);
238                            }
239                        }
240
241                        let mut config = None;
242                        let mut safetensor = None;
243
244                        // Sort local paths for the adapter configs and safetensors files
245                        let paths = adapters_paths
246                            .get(&adapter.name)
247                            .unwrap_or_else(|| panic!("Adapter {} not found.", adapter.name));
248                        for path in paths {
249                            if path.extension().unwrap() == "safetensors" {
250                                safetensor = Some(path.to_owned());
251                            } else {
252                                let conf = fs::read_to_string(path)?;
253                                let lora_config: LoraConfig = serde_json::from_str(&conf)?;
254                                config = Some(lora_config);
255                            }
256                        }
257
258                        let (config, safetensor) = (config.unwrap(), safetensor.unwrap());
259                        output.insert(adapter.name.clone(), (safetensor, config));
260                    }
261                    Some(output)
262                } else {
263                    None
264                };
265
266            Ok(AdapterPaths::XLora {
267                adapter_configs: Some(adapters_configs),
268                adapter_safetensors: Some(adapters_safetensors),
269                classifier_path,
270                xlora_order: Some(xlora_order.clone()),
271                xlora_config,
272                lora_preload_adapter_info,
273            })
274        }
275        (Some(adapter_ids), None, None) => {
276            let mut lora_adapter_paths = Vec::new();
277            for adapter_id in adapter_ids {
278                info!("Loading adapter at `{adapter_id}`");
279
280                let api = {
281                    let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
282                    let mut api = ApiBuilder::from_cache(cache)
283                        .with_progress(true)
284                        .with_token(get_token(token_source)?);
285                    if let Some(cache_dir) = crate::hf_hub_cache_dir() {
286                        api = api.with_cache_dir(cache_dir);
287                    }
288                    api.build().map_err(hanzo_ml::Error::msg)?
289                };
290                let api = api.repo(Repo::with_revision(
291                    adapter_id.clone(),
292                    RepoType::Model,
293                    revision.clone(),
294                ));
295
296                let adapter_path_buf = std::path::Path::new(adapter_id);
297                let config_path = crate::pipeline::hf::get_file(
298                    &api,
299                    adapter_path_buf,
300                    "adapter_config.json",
301                    &revision,
302                )?;
303                let adapter_path = crate::pipeline::hf::get_file(
304                    &api,
305                    adapter_path_buf,
306                    "adapter_model.safetensors",
307                    &revision,
308                )?;
309                let lora_config: hanzo_quant::LoraConfig =
310                    serde_json::from_str(&fs::read_to_string(config_path)?)?;
311
312                lora_adapter_paths.push(LoraAdapterPaths {
313                    lora_config,
314                    adapter_path,
315                });
316            }
317
318            Ok(AdapterPaths::Lora(lora_adapter_paths))
319        }
320        (None, None, None) => Ok(AdapterPaths::None),
321        _ => anyhow::bail!(
322            "Incorrect configuration for an adapter model. Lora and XLora are mutually exclusive."
323        ),
324    }
325}
326
327pub fn get_model_paths(
328    revision: String,
329    token_source: &TokenSource,
330    quantized_model_id: Option<&String>,
331    quantized_filename: Option<&Vec<String>>,
332    api: &ApiRepo,
333    model_id: &Path,
334    loading_from_uqff: bool,
335) -> Result<Vec<PathBuf>> {
336    match quantized_filename {
337        Some(names) => {
338            let id = quantized_model_id.unwrap();
339            let mut files = Vec::new();
340
341            for name in names {
342                let qapi = {
343                    let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
344                    let mut api = ApiBuilder::from_cache(cache)
345                        .with_progress(true)
346                        .with_token(get_token(token_source)?);
347                    if let Some(cache_dir) = crate::hf_hub_cache_dir() {
348                        api = api.with_cache_dir(cache_dir);
349                    }
350                    api.build().map_err(hanzo_ml::Error::msg)?
351                };
352                let qapi = qapi.repo(Repo::with_revision(
353                    id.to_string(),
354                    RepoType::Model,
355                    revision.clone(),
356                ));
357                let model_id = Path::new(&id);
358                files.push(api_get_file!(qapi, name, model_id, &revision));
359            }
360            Ok(files)
361        }
362        None => {
363            // We only match these patterns for model names
364            let safetensor_match = Regex::new(SAFETENSOR_MATCH)?;
365            let quant_safetensor_match = Regex::new(QUANT_SAFETENSOR_MATCH)?;
366            let consolidated_safetensor_match = Regex::new(CONSOLIDATED_SAFETENSOR_MATCH)?;
367            let pickle_match = Regex::new(PICKLE_MATCH)?;
368
369            let mut filenames = vec![];
370            let listing = api_dir_list!(api, model_id, true, &revision).filter(|x| {
371                safetensor_match.is_match(x)
372                    || pickle_match.is_match(x)
373                    || quant_safetensor_match.is_match(x)
374                    || consolidated_safetensor_match.is_match(x)
375                    || x == UQFF_RESIDUAL_SAFETENSORS
376            });
377            let safetensors = listing
378                .clone()
379                .filter(|x| x.ends_with(".safetensors"))
380                .collect::<Vec<_>>();
381            let pickles = listing
382                .clone()
383                .filter(|x| x.ends_with(".pth") || x.ends_with(".pt") || x.ends_with(".bin"))
384                .collect::<Vec<_>>();
385            let uqff_residual = listing
386                .clone()
387                .filter(|x| x == UQFF_RESIDUAL_SAFETENSORS)
388                .collect::<Vec<_>>();
389            let files = if !safetensors.is_empty() {
390                // Always prefer safetensors
391                safetensors
392            } else if !pickles.is_empty() {
393                // Fall back to pickle
394                pickles
395            } else if !uqff_residual.is_empty() && loading_from_uqff {
396                uqff_residual
397            } else {
398                anyhow::bail!("Expected file with extension one of .safetensors, .pth, .pt, .bin.");
399            };
400            trace!(
401                "Found model weight filenames {:?}",
402                files
403                    .iter()
404                    .map(|x| x.split('/').next_back().unwrap())
405                    .collect::<Vec<_>>()
406            );
407            for rfilename in files {
408                filenames.push(api_get_file!(api, &rfilename, model_id, &revision));
409            }
410            Ok(filenames)
411        }
412    }
413}
414
415/// Find and parse the appropriate [`ChatTemplate`], and ensure is has a valid [`ChatTemplate.chat_template`].
416/// If the provided `tokenizer_config.json` from [`ModelPaths.get_template_filename`] does not
417/// have a `chat_template`, use the provided one.
418///
419/// - Uses `chat_template_fallback` if `paths` does not contain a chat template file. This may be a literal or .json file.
420/// - `chat_template_ovrd` (GGUF chat template content) causes the usage of that string chat template initially.
421///   Falls back to `chat_template_file` if it is invalid. *The user must add the bos/unk/eos tokens manually if this
422///   is used.*
423///
424/// THE FOLLOWING IS IGNORED:
425/// After this, if the `chat_template_explicit` filename is specified (a json with one field: "chat_template" OR a jinja file),
426///  the chat template is overwritten with this chat template.
427#[allow(clippy::borrowed_box)]
428pub(crate) fn get_chat_template(
429    paths: &Box<dyn ModelPaths>,
430    jinja_explicit: Option<&String>,
431    chat_template_explicit: Option<&String>,
432    chat_template_fallback: Option<&String>,
433    chat_template_ovrd: Option<String>,
434) -> ChatTemplate {
435    // Get template content, this may be overridden.
436    let template_content = if let Some(template_filename) = paths.get_template_filename() {
437        if !["jinja", "json"].contains(
438            &template_filename
439                .extension()
440                .expect("Template filename must be a file")
441                .to_string_lossy()
442                .to_string()
443                .as_str(),
444        ) {
445            panic!("Template filename {template_filename:?} must end with `.json` or `.jinja`.");
446        }
447        Some(fs::read_to_string(template_filename).expect("Loading chat template failed."))
448    } else if chat_template_fallback.is_some_and(|f| f.ends_with(".json")) {
449        // User specified a file
450        let template_filename = chat_template_fallback
451            .expect("A tokenizer config or chat template file path must be specified.");
452        Some(fs::read_to_string(template_filename).expect("Loading chat template failed."))
453    } else if chat_template_ovrd.is_some() {
454        None
455    } else {
456        debug!("No chat template file found. Chat template may be set via `chat_template.json` or processor config.");
457        None
458    };
459    let mut template: ChatTemplate = match chat_template_ovrd {
460        Some(chat_template) => {
461            // In this case the override chat template is being used. The user must add the bos/eos/unk toks themselves.
462            debug!("Using literal chat template.");
463            let mut template = ChatTemplate::default();
464            template.chat_template = Some(ChatTemplateValue(Either::Left(chat_template)));
465            template
466        }
467        None => {
468            if let Some(ref content) = template_content {
469                // Check if template_filename is a .jinja file
470                if let Some(template_filename) = paths.get_template_filename() {
471                    if template_filename.extension().map(|e| e.to_str()) == Some(Some("jinja")) {
472                        debug!("Using chat template from .jinja file.");
473                        // Load special tokens (bos/eos/unk) from tokenizer_config.json
474                        // in the same directory, matching HF's behavior where
475                        // apply_chat_template passes self.special_tokens_map to the template.
476                        let mut template = template_filename
477                            .parent()
478                            .map(|dir| dir.join("tokenizer_config.json"))
479                            .filter(|p| p.exists())
480                            .and_then(|p| fs::read_to_string(p).ok())
481                            .and_then(|s| serde_json::from_str::<ChatTemplate>(&s).ok())
482                            .unwrap_or_else(|| {
483                                // Fallback: older UQFF repos may not have tokenizer_config.json.
484                                // Try to extract bos/eos tokens from the tokenizer.json's
485                                // added_tokens list to avoid rendering "none" in the template.
486                                let mut ct = ChatTemplate::default();
487                                if let Some(tok_path) = paths
488                                    .get_tokenizer_filename()
489                                    .parent()
490                                    .map(|d| d.join("tokenizer.json"))
491                                    .filter(|p| p.exists())
492                                    .or_else(|| {
493                                        template_filename
494                                            .parent()
495                                            .map(|d| d.join("tokenizer.json"))
496                                            .filter(|p| p.exists())
497                                    })
498                                {
499                                    if let Some(tok_json) =
500                                        fs::read_to_string(&tok_path).ok().and_then(|s| {
501                                            serde_json::from_str::<serde_json::Value>(&s).ok()
502                                        })
503                                    {
504                                        let added = tok_json
505                                            .get("added_tokens")
506                                            .and_then(serde_json::Value::as_array);
507                                        for token in added.into_iter().flatten() {
508                                            let content = token
509                                                .get("content")
510                                                .and_then(serde_json::Value::as_str)
511                                                .unwrap_or("");
512                                            let special = token
513                                                .get("special")
514                                                .and_then(serde_json::Value::as_bool)
515                                                .unwrap_or(false);
516                                            if special {
517                                                if content == "<bos>" {
518                                                    ct.bos_token = Some(BeginEndUnkPadTok(
519                                                        Either::Left(content.to_string()),
520                                                    ));
521                                                } else if content == "<eos>" {
522                                                    ct.eos_token = Some(BeginEndUnkPadTok(
523                                                        Either::Left(content.to_string()),
524                                                    ));
525                                                } else if content == "<unk>" {
526                                                    ct.unk_token = Some(BeginEndUnkPadTok(
527                                                        Either::Left(content.to_string()),
528                                                    ));
529                                                }
530                                            }
531                                        }
532                                    }
533                                }
534                                ct
535                            });
536                        template.chat_template =
537                            Some(ChatTemplateValue(Either::Left(content.clone())));
538                        template
539                    } else {
540                        serde_json::from_str(content).unwrap()
541                    }
542                } else {
543                    serde_json::from_str(content).unwrap()
544                }
545            } else {
546                // No template content available; downstream code may fill in from
547                // chat_template.json, processor_config, or jinja_explicit.
548                ChatTemplate::default()
549            }
550        }
551    };
552    // Overwrite to use any present `chat_template.json`, only if there is not one present already.
553    if template.chat_template.is_none() {
554        if let Some(chat_template_explicit) = chat_template_explicit {
555            let ct =
556                fs::read_to_string(chat_template_explicit).expect("Loading chat template failed.");
557
558            let new_chat_template = if chat_template_explicit.ends_with(".jinja") {
559                ct
560            } else {
561                #[derive(Debug, serde::Deserialize)]
562                struct AutomaticTemplate {
563                    chat_template: String,
564                }
565                let deser: AutomaticTemplate = serde_json::from_str(&ct).unwrap();
566                deser.chat_template
567            };
568
569            template.chat_template = Some(ChatTemplateValue(Either::Left(new_chat_template)));
570        }
571    }
572
573    // JINJA explicit
574    if let Some(jinja_explicit) = jinja_explicit {
575        if !jinja_explicit.ends_with(".jinja") {
576            panic!("jinja_explicit must end with .jinja!");
577        }
578
579        let ct = fs::read_to_string(jinja_explicit).expect("Loading chat template failed.");
580
581        template.chat_template = Some(ChatTemplateValue(Either::Left(ct)));
582    }
583
584    let processor_conf: Option<crate::vision_models::processor_config::ProcessorConfig> = paths
585        .get_processor_config()
586        .as_ref()
587        .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
588    if let Some(processor_conf) = processor_conf {
589        if processor_conf.chat_template.is_some() {
590            template.chat_template = processor_conf
591                .chat_template
592                .map(|x| ChatTemplateValue(Either::Left(x)));
593        }
594    }
595
596    #[derive(Debug, serde::Deserialize)]
597    struct SpecifiedTemplate {
598        chat_template: String,
599        bos_token: Option<String>,
600        eos_token: Option<String>,
601        unk_token: Option<String>,
602    }
603
604    if template.chat_template.is_some() {
605        return template;
606    };
607
608    match &template.chat_template {
609        Some(_) => template,
610        None => {
611            if let Some(template_content) = template_content {
612                info!("`tokenizer_config.json` does not contain a chat template, attempting to use specified JINJA chat template.");
613                let mut deser: HashMap<String, Value> =
614                    serde_json::from_str(&template_content).unwrap();
615
616                match chat_template_fallback.cloned() {
617                    Some(t) => {
618                        info!("Loading specified loading chat template file at `{t}`.");
619                        let templ: SpecifiedTemplate =
620                            serde_json::from_str(&fs::read_to_string(t.clone()).unwrap()).unwrap();
621                        deser.insert(
622                            "chat_template".to_string(),
623                            Value::String(templ.chat_template),
624                        );
625                        if let Some(bos_token) = templ.bos_token {
626                            deser.insert("bos_token".to_string(), Value::String(bos_token));
627                        }
628                        if let Some(eos_token) = templ.eos_token {
629                            deser.insert("eos_token".to_string(), Value::String(eos_token));
630                        }
631                        if let Some(unk_token) = templ.unk_token {
632                            deser.insert("unk_token".to_string(), Value::String(unk_token));
633                        }
634                    }
635                    None => {
636                        warn!("No specified chat template. No chat template will be used. Only prompts will be accepted, not messages.");
637                        deser.insert("chat_template".to_string(), Value::Null);
638                    }
639                }
640
641                let ser = serde_json::to_string_pretty(&deser)
642                    .expect("Serialization of modified chat template failed.");
643                serde_json::from_str(&ser).unwrap()
644            } else {
645                warn!("No chat template source found. No chat template will be used. Only prompts will be accepted, not messages.");
646                template
647            }
648        }
649    }
650}
651
/// Unit tests for the weight-file name patterns.
#[cfg(test)]
mod tests {
    use regex_automata::meta::Regex;

    use super::{PICKLE_MATCH, SAFETENSOR_MATCH};

    /// Compiles `pattern` and asserts that it matches every name in `positive_ids` and none
    /// of the names in `negative_ids`.
    fn assert_pattern(
        pattern: &str,
        positive_ids: &[&str],
        negative_ids: &[&str],
    ) -> anyhow::Result<()> {
        let regex = Regex::new(pattern)?;
        for id in positive_ids {
            assert!(regex.is_match(*id), "`{id}` should match `{pattern}`");
        }
        for id in negative_ids {
            assert!(!regex.is_match(*id), "`{id}` should not match `{pattern}`");
        }
        Ok(())
    }

    #[test]
    fn match_safetensors() -> anyhow::Result<()> {
        assert_pattern(
            SAFETENSOR_MATCH,
            &[
                "model-00001-of-00001.safetensors",
                "model-00002-of-00002.safetensors",
                "model-00003-of-00003.safetensors",
                "model-00004-of-00004.safetensors",
                "model-00005-of-00005.safetensors",
                "model-00006-of-00006.safetensors",
            ],
            &[
                "model-0000a-of-00002.safetensors",
                "consolidated.safetensors",
            ],
        )
    }

    #[test]
    fn match_pickle() -> anyhow::Result<()> {
        assert_pattern(
            PICKLE_MATCH,
            &[
                "pytorch_model-00001-of-00002.bin",
                "pytorch_model-00002-of-00002.bin",
            ],
            &[
                "pytorch_model-000001-of-00001.bin",
                "pytorch_model-0000a-of-00002.bin",
                "pytorch_model-000-of-00003.bin",
                "pytorch_consolidated.bin",
            ],
        )
    }
}