Skip to main content

hanzo_engine/
model_loader.rs

1use std::{
2    fs::{self, File},
3    path::PathBuf,
4    str::FromStr,
5};
6
7use hanzo_quant::MULTI_LORA_DELIMITER;
8
9use crate::{
10    get_toml_selected_model_dtype,
11    pipeline::{
12        AutoLoaderBuilder, DiffusionLoaderBuilder, GGMLLoaderBuilder, GGMLSpecificConfig,
13        GGUFLoaderBuilder, GGUFSpecificConfig, MultimodalLoaderBuilder, MultimodalSpecificConfig,
14        NormalLoaderBuilder, NormalSpecificConfig,
15    },
16    toml_selector::get_toml_selected_model_device_map_params,
17    AutoDeviceMapParams, EmbeddingLoaderBuilder, EmbeddingSpecificConfig, Loader, ModelDType,
18    ModelSelected, SpeechLoader, TomlLoaderArgs, TomlSelector, Topology, GGUF_MULTI_FILE_DELIMITER,
19    UQFF_MULTI_FILE_DELIMITER,
20};
21
22/// A builder for a loader using the selected model.
23pub struct LoaderBuilder {
24    model: ModelSelected,
25    no_kv_cache: bool,
26    chat_template: Option<String>,
27    jinja_explicit: Option<String>,
28}
29
30impl LoaderBuilder {
31    pub fn new(model: ModelSelected) -> Self {
32        Self {
33            model,
34            no_kv_cache: false,
35            chat_template: None,
36            jinja_explicit: None,
37        }
38    }
39
40    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
41        self.no_kv_cache = no_kv_cache;
42        self
43    }
44    pub fn with_chat_template(mut self, chat_template: Option<String>) -> Self {
45        self.chat_template = chat_template;
46        self
47    }
48    pub fn with_jinja_explicit(mut self, jinja_explicit: Option<String>) -> Self {
49        self.jinja_explicit = jinja_explicit;
50        self
51    }
52
53    pub fn build(self) -> anyhow::Result<Box<dyn Loader>> {
54        loader_from_model_selected(self)
55    }
56}
57
58pub fn get_tgt_non_granular_index(model: &ModelSelected) -> Option<usize> {
59    match model {
60        ModelSelected::Plain { .. }
61        | ModelSelected::Run { .. }
62        | ModelSelected::Lora { .. }
63        | ModelSelected::GGUF { .. }
64        | ModelSelected::LoraGGUF { .. }
65        | ModelSelected::GGML { .. }
66        | ModelSelected::LoraGGML { .. }
67        | ModelSelected::Toml { .. }
68        | ModelSelected::MultimodalPlain { .. }
69        | ModelSelected::DiffusionPlain { .. }
70        | ModelSelected::Speech { .. }
71        | ModelSelected::Embedding { .. } => None,
72        ModelSelected::XLora {
73            tgt_non_granular_index,
74            ..
75        }
76        | ModelSelected::XLoraGGUF {
77            tgt_non_granular_index,
78            ..
79        }
80        | ModelSelected::XLoraGGML {
81            tgt_non_granular_index,
82            ..
83        } => *tgt_non_granular_index,
84        ModelSelected::MultiModel { .. } => {
85            panic!("MultiModel variant should not be used in model loading functions")
86        }
87    }
88}
89
90pub fn get_model_dtype(model: &ModelSelected) -> anyhow::Result<ModelDType> {
91    match model {
92        ModelSelected::Plain { dtype, .. }
93        | ModelSelected::Lora { dtype, .. }
94        | ModelSelected::XLora { dtype, .. }
95        | ModelSelected::MultimodalPlain { dtype, .. }
96        | ModelSelected::DiffusionPlain { dtype, .. }
97        | ModelSelected::GGML { dtype, .. }
98        | ModelSelected::GGUF { dtype, .. }
99        | ModelSelected::XLoraGGUF { dtype, .. }
100        | ModelSelected::XLoraGGML { dtype, .. }
101        | ModelSelected::LoraGGUF { dtype, .. }
102        | ModelSelected::LoraGGML { dtype, .. }
103        | ModelSelected::Run { dtype, .. }
104        | ModelSelected::Speech { dtype, .. }
105        | ModelSelected::Embedding { dtype, .. } => Ok(*dtype),
106        ModelSelected::Toml { file } => {
107            let selector: TomlSelector = toml::from_str(
108                &fs::read_to_string(file.clone())
109                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
110            )?;
111            Ok(get_toml_selected_model_dtype(&selector))
112        }
113        ModelSelected::MultiModel { .. } => {
114            anyhow::bail!("MultiModel variant should not be used in model loading functions")
115        }
116    }
117}
118
119pub fn get_auto_device_map_params(model: &ModelSelected) -> anyhow::Result<AutoDeviceMapParams> {
120    match model {
121        ModelSelected::Plain {
122            max_seq_len,
123            max_batch_size,
124            ..
125        }
126        | ModelSelected::Lora {
127            max_seq_len,
128            max_batch_size,
129            ..
130        }
131        | ModelSelected::XLora {
132            max_seq_len,
133            max_batch_size,
134            ..
135        }
136        | ModelSelected::GGML {
137            max_seq_len,
138            max_batch_size,
139            ..
140        }
141        | ModelSelected::GGUF {
142            max_seq_len,
143            max_batch_size,
144            ..
145        }
146        | ModelSelected::XLoraGGUF {
147            max_seq_len,
148            max_batch_size,
149            ..
150        }
151        | ModelSelected::XLoraGGML {
152            max_seq_len,
153            max_batch_size,
154            ..
155        }
156        | ModelSelected::LoraGGUF {
157            max_seq_len,
158            max_batch_size,
159            ..
160        }
161        | ModelSelected::LoraGGML {
162            max_seq_len,
163            max_batch_size,
164            ..
165        } => Ok(AutoDeviceMapParams::Text {
166            max_seq_len: *max_seq_len,
167            max_batch_size: *max_batch_size,
168        }),
169        ModelSelected::Run {
170            max_seq_len,
171            max_batch_size,
172            max_image_length,
173            max_num_images,
174            ..
175        } => {
176            if max_num_images.is_some() || max_image_length.is_some() {
177                let max_image_length =
178                    max_image_length.unwrap_or(AutoDeviceMapParams::DEFAULT_MAX_IMAGE_LENGTH);
179                Ok(AutoDeviceMapParams::Multimodal {
180                    max_seq_len: *max_seq_len,
181                    max_batch_size: *max_batch_size,
182                    max_image_shape: (max_image_length, max_image_length),
183                    max_num_images: max_num_images
184                        .unwrap_or(AutoDeviceMapParams::DEFAULT_MAX_NUM_IMAGES),
185                })
186            } else {
187                Ok(AutoDeviceMapParams::Text {
188                    max_seq_len: *max_seq_len,
189                    max_batch_size: *max_batch_size,
190                })
191            }
192        }
193        ModelSelected::MultimodalPlain {
194            max_seq_len,
195            max_batch_size,
196            max_image_length,
197            max_num_images,
198            ..
199        } => Ok(AutoDeviceMapParams::Multimodal {
200            max_seq_len: *max_seq_len,
201            max_batch_size: *max_batch_size,
202            max_image_shape: (*max_image_length, *max_image_length),
203            max_num_images: *max_num_images,
204        }),
205        ModelSelected::DiffusionPlain { .. }
206        | ModelSelected::Speech { .. }
207        | ModelSelected::Embedding { .. } => Ok(AutoDeviceMapParams::default_text()),
208        ModelSelected::Toml { file } => {
209            let selector: TomlSelector = toml::from_str(
210                &fs::read_to_string(file.clone())
211                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
212            )?;
213            get_toml_selected_model_device_map_params(&selector)
214        }
215        ModelSelected::MultiModel { .. } => {
216            anyhow::bail!("MultiModel variant should not be used in model loading functions")
217        }
218    }
219}
220
221fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loader>> {
222    let loader: Box<dyn Loader> = match args.model {
223        ModelSelected::Toml { file } => {
224            let selector: TomlSelector = toml::from_str(
225                &fs::read_to_string(file.clone())
226                    .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
227            )?;
228            let args = TomlLoaderArgs {
229                chat_template: args.chat_template,
230                no_kv_cache: args.no_kv_cache,
231                jinja_explicit: args.jinja_explicit,
232            };
233            (selector, args).try_into()?
234        }
235        ModelSelected::Plain {
236            model_id,
237            tokenizer_json,
238            arch,
239            dtype: _,
240            topology,
241            organization,
242            write_uqff,
243            from_uqff,
244            imatrix,
245            calibration_file,
246            max_seq_len: _,
247            max_batch_size: _,
248            hf_cache_path,
249            matformer_config_path,
250            matformer_slice_name,
251        } => NormalLoaderBuilder::new(
252            NormalSpecificConfig {
253                topology: Topology::from_option_path(topology)?,
254                organization: organization.unwrap_or_default(),
255                write_uqff,
256                from_uqff: from_uqff.map(|x| {
257                    x.split(UQFF_MULTI_FILE_DELIMITER)
258                        .map(PathBuf::from_str)
259                        .map(|x| x.unwrap())
260                        .collect::<Vec<_>>()
261                }),
262                imatrix,
263                calibration_file,
264                hf_cache_path,
265                matformer_config_path,
266                matformer_slice_name,
267            },
268            args.chat_template,
269            tokenizer_json,
270            Some(model_id),
271            args.no_kv_cache,
272            args.jinja_explicit,
273        )
274        .build(arch)?,
275        ModelSelected::Run {
276            model_id,
277            tokenizer_json,
278            dtype: _,
279            topology,
280            organization,
281            write_uqff,
282            from_uqff,
283            imatrix,
284            calibration_file,
285            max_edge,
286            max_seq_len: _,
287            max_batch_size: _,
288            max_num_images: _,
289            max_image_length: _,
290            hf_cache_path,
291            matformer_config_path,
292            matformer_slice_name,
293        } => {
294            let builder = AutoLoaderBuilder::new(
295                NormalSpecificConfig {
296                    topology: Topology::from_option_path(topology.clone())?,
297                    organization: organization.unwrap_or_default(),
298                    write_uqff: write_uqff.clone(),
299                    from_uqff: from_uqff.clone().map(|x| {
300                        x.split(UQFF_MULTI_FILE_DELIMITER)
301                            .map(PathBuf::from_str)
302                            .map(|x| x.unwrap())
303                            .collect::<Vec<_>>()
304                    }),
305                    imatrix: imatrix.clone(),
306                    calibration_file: calibration_file.clone(),
307                    hf_cache_path: hf_cache_path.clone(),
308                    matformer_config_path: matformer_config_path.clone(),
309                    matformer_slice_name: matformer_slice_name.clone(),
310                },
311                MultimodalSpecificConfig {
312                    topology: Topology::from_option_path(topology.clone())?,
313                    write_uqff: write_uqff.clone(),
314                    from_uqff: from_uqff.clone().map(|x| {
315                        x.split(UQFF_MULTI_FILE_DELIMITER)
316                            .map(PathBuf::from_str)
317                            .map(|x| x.unwrap())
318                            .collect::<Vec<_>>()
319                    }),
320                    max_edge,
321                    calibration_file,
322                    imatrix,
323                    hf_cache_path: hf_cache_path.clone(),
324                    matformer_config_path,
325                    matformer_slice_name,
326                    organization: organization.unwrap_or_default(),
327                },
328                EmbeddingSpecificConfig {
329                    topology: Topology::from_option_path(topology)?,
330                    write_uqff,
331                    from_uqff: from_uqff.map(|x| {
332                        x.split(UQFF_MULTI_FILE_DELIMITER)
333                            .map(PathBuf::from_str)
334                            .map(|x| x.unwrap())
335                            .collect::<Vec<_>>()
336                    }),
337                    hf_cache_path: hf_cache_path.clone(),
338                },
339                args.chat_template,
340                tokenizer_json,
341                model_id,
342                args.no_kv_cache,
343                args.jinja_explicit,
344            );
345            let builder = if let Some(ref path) = hf_cache_path {
346                builder.hf_cache_path(path.clone())
347            } else {
348                builder
349            };
350            builder.build()
351        }
352        ModelSelected::MultimodalPlain {
353            model_id,
354            tokenizer_json,
355            arch,
356            dtype: _,
357            topology,
358            write_uqff,
359            from_uqff,
360            max_edge,
361            calibration_file,
362            max_seq_len: _,
363            max_batch_size: _,
364            max_num_images: _,
365            max_image_length: _,
366            hf_cache_path,
367            imatrix,
368            matformer_config_path,
369            matformer_slice_name,
370            organization,
371        } => MultimodalLoaderBuilder::new(
372            MultimodalSpecificConfig {
373                topology: Topology::from_option_path(topology)?,
374                write_uqff,
375                from_uqff: from_uqff.map(|x| {
376                    x.split(UQFF_MULTI_FILE_DELIMITER)
377                        .map(PathBuf::from_str)
378                        .map(|x| x.unwrap())
379                        .collect::<Vec<_>>()
380                }),
381                max_edge,
382                calibration_file,
383                imatrix,
384                hf_cache_path,
385                matformer_config_path,
386                matformer_slice_name,
387                organization: organization.unwrap_or_default(),
388            },
389            args.chat_template,
390            tokenizer_json,
391            Some(model_id),
392            args.jinja_explicit,
393        )
394        .build(arch),
395        ModelSelected::DiffusionPlain {
396            model_id,
397            arch,
398            dtype: _,
399        } => DiffusionLoaderBuilder::new(Some(model_id)).build(arch),
400        ModelSelected::Speech {
401            model_id,
402            dac_model_id,
403            arch,
404            ..
405        } => Box::new(SpeechLoader {
406            model_id,
407            dac_model_id,
408            arch,
409            cfg: None,
410        }),
411        ModelSelected::XLora {
412            model_id,
413            xlora_model_id,
414            order,
415            tokenizer_json,
416            tgt_non_granular_index,
417            arch,
418            dtype: _,
419            topology,
420            write_uqff,
421            from_uqff,
422            max_seq_len: _,
423            max_batch_size: _,
424            hf_cache_path,
425        } => NormalLoaderBuilder::new(
426            NormalSpecificConfig {
427                topology: Topology::from_option_path(topology)?,
428                organization: Default::default(),
429                write_uqff,
430                from_uqff: from_uqff.map(|x| {
431                    x.split(UQFF_MULTI_FILE_DELIMITER)
432                        .map(PathBuf::from_str)
433                        .map(|x| x.unwrap())
434                        .collect::<Vec<_>>()
435                }),
436                imatrix: None,
437                calibration_file: None,
438                hf_cache_path,
439                matformer_config_path: None,
440                matformer_slice_name: None,
441            },
442            args.chat_template,
443            tokenizer_json,
444            model_id,
445            args.no_kv_cache,
446            args.jinja_explicit,
447        )
448        .with_xlora(
449            xlora_model_id,
450            serde_json::from_reader(
451                File::open(order.clone())
452                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
453            )?,
454            args.no_kv_cache,
455            tgt_non_granular_index,
456        )
457        .build(arch)?,
458        ModelSelected::Lora {
459            model_id,
460            tokenizer_json,
461            adapter_model_id,
462            arch,
463            dtype: _,
464            topology,
465            write_uqff,
466            from_uqff,
467            max_seq_len: _,
468            max_batch_size: _,
469            hf_cache_path,
470        } => NormalLoaderBuilder::new(
471            NormalSpecificConfig {
472                topology: Topology::from_option_path(topology)?,
473                organization: Default::default(),
474                write_uqff,
475                from_uqff: from_uqff.map(|x| {
476                    x.split(UQFF_MULTI_FILE_DELIMITER)
477                        .map(PathBuf::from_str)
478                        .map(|x| x.unwrap())
479                        .collect::<Vec<_>>()
480                }),
481                imatrix: None,
482                calibration_file: None,
483                hf_cache_path,
484                matformer_config_path: None,
485                matformer_slice_name: None,
486            },
487            args.chat_template,
488            tokenizer_json,
489            model_id,
490            args.no_kv_cache,
491            args.jinja_explicit,
492        )
493        .with_lora(
494            adapter_model_id
495                .split(MULTI_LORA_DELIMITER)
496                .map(ToString::to_string)
497                .collect(),
498        )
499        .build(arch)?,
500        ModelSelected::GGUF {
501            tok_model_id,
502            quantized_model_id,
503            quantized_filename,
504            topology,
505            ..
506        } => GGUFLoaderBuilder::new(
507            args.chat_template,
508            tok_model_id,
509            quantized_model_id,
510            quantized_filename
511                .split(GGUF_MULTI_FILE_DELIMITER)
512                .map(ToOwned::to_owned)
513                .collect::<Vec<_>>(),
514            GGUFSpecificConfig {
515                topology: Topology::from_option_path(topology)?,
516            },
517            args.no_kv_cache,
518            args.jinja_explicit,
519        )
520        .build(),
521        ModelSelected::XLoraGGUF {
522            tok_model_id,
523            quantized_model_id,
524            quantized_filename,
525            xlora_model_id,
526            order,
527            tgt_non_granular_index,
528            topology,
529            ..
530        } => GGUFLoaderBuilder::new(
531            args.chat_template,
532            tok_model_id,
533            quantized_model_id,
534            quantized_filename
535                .split(GGUF_MULTI_FILE_DELIMITER)
536                .map(ToOwned::to_owned)
537                .collect::<Vec<_>>(),
538            GGUFSpecificConfig {
539                topology: Topology::from_option_path(topology)?,
540            },
541            args.no_kv_cache,
542            args.jinja_explicit,
543        )
544        .with_xlora(
545            xlora_model_id,
546            serde_json::from_reader(
547                File::open(order.clone())
548                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
549            )?,
550            args.no_kv_cache,
551            tgt_non_granular_index,
552        )
553        .build(),
554        ModelSelected::LoraGGUF {
555            tok_model_id,
556            quantized_model_id,
557            quantized_filename,
558            adapters_model_id,
559            order,
560            topology,
561            ..
562        } => GGUFLoaderBuilder::new(
563            args.chat_template,
564            tok_model_id,
565            quantized_model_id,
566            quantized_filename
567                .split(GGUF_MULTI_FILE_DELIMITER)
568                .map(ToOwned::to_owned)
569                .collect::<Vec<_>>(),
570            GGUFSpecificConfig {
571                topology: Topology::from_option_path(topology)?,
572            },
573            args.no_kv_cache,
574            args.jinja_explicit,
575        )
576        .with_lora(
577            adapters_model_id,
578            serde_json::from_reader(
579                File::open(order.clone())
580                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
581            )?,
582        )
583        .build(),
584        ModelSelected::GGML {
585            tok_model_id,
586            tokenizer_json,
587            quantized_model_id,
588            quantized_filename,
589            gqa,
590            topology,
591            ..
592        } => GGMLLoaderBuilder::new(
593            GGMLSpecificConfig {
594                gqa,
595                topology: Topology::from_option_path(topology)?,
596            },
597            args.chat_template,
598            tokenizer_json,
599            Some(tok_model_id),
600            quantized_model_id,
601            quantized_filename,
602            args.no_kv_cache,
603            args.jinja_explicit,
604        )
605        .build(),
606        ModelSelected::XLoraGGML {
607            tok_model_id,
608            tokenizer_json,
609            quantized_model_id,
610            quantized_filename,
611            xlora_model_id,
612            order,
613            tgt_non_granular_index,
614            gqa,
615            topology,
616            ..
617        } => GGMLLoaderBuilder::new(
618            GGMLSpecificConfig {
619                gqa,
620                topology: Topology::from_option_path(topology)?,
621            },
622            args.chat_template,
623            tokenizer_json,
624            tok_model_id,
625            quantized_model_id,
626            quantized_filename,
627            args.no_kv_cache,
628            args.jinja_explicit,
629        )
630        .with_xlora(
631            xlora_model_id,
632            serde_json::from_reader(
633                File::open(order.clone())
634                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
635            )?,
636            args.no_kv_cache,
637            tgt_non_granular_index,
638        )
639        .build(),
640        ModelSelected::LoraGGML {
641            tok_model_id,
642            tokenizer_json,
643            quantized_model_id,
644            quantized_filename,
645            adapters_model_id,
646            order,
647            gqa,
648            topology,
649            ..
650        } => GGMLLoaderBuilder::new(
651            GGMLSpecificConfig {
652                gqa,
653                topology: Topology::from_option_path(topology)?,
654            },
655            args.chat_template,
656            tokenizer_json,
657            tok_model_id,
658            quantized_model_id,
659            quantized_filename,
660            args.no_kv_cache,
661            args.jinja_explicit,
662        )
663        .with_lora(
664            adapters_model_id,
665            serde_json::from_reader(
666                File::open(order.clone())
667                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
668            )?,
669        )
670        .build(),
671        ModelSelected::Embedding {
672            model_id,
673            tokenizer_json,
674            arch,
675            dtype: _,
676            topology,
677            write_uqff,
678            from_uqff,
679            hf_cache_path,
680        } => EmbeddingLoaderBuilder::new(
681            EmbeddingSpecificConfig {
682                topology: Topology::from_option_path(topology)?,
683                write_uqff,
684                from_uqff: from_uqff.map(|x| {
685                    x.split(UQFF_MULTI_FILE_DELIMITER)
686                        .map(PathBuf::from_str)
687                        .map(|x| x.unwrap())
688                        .collect::<Vec<_>>()
689                }),
690                hf_cache_path,
691            },
692            tokenizer_json,
693            Some(model_id),
694        )
695        .build(arch),
696        ModelSelected::MultiModel { .. } => {
697            anyhow::bail!("MultiModel variant should not be used in model loading functions")
698        }
699    };
700    Ok(loader)
701}