1use std::{
2 fs::{self, File},
3 path::PathBuf,
4 str::FromStr,
5};
6
7use hanzo_quant::MULTI_LORA_DELIMITER;
8
9use crate::{
10 get_toml_selected_model_dtype,
11 pipeline::{
12 AutoLoaderBuilder, DiffusionLoaderBuilder, GGMLLoaderBuilder, GGMLSpecificConfig,
13 GGUFLoaderBuilder, GGUFSpecificConfig, MultimodalLoaderBuilder, MultimodalSpecificConfig,
14 NormalLoaderBuilder, NormalSpecificConfig,
15 },
16 toml_selector::get_toml_selected_model_device_map_params,
17 AutoDeviceMapParams, EmbeddingLoaderBuilder, EmbeddingSpecificConfig, Loader, ModelDType,
18 ModelSelected, SpeechLoader, TomlLoaderArgs, TomlSelector, Topology, GGUF_MULTI_FILE_DELIMITER,
19 UQFF_MULTI_FILE_DELIMITER,
20};
21
22pub struct LoaderBuilder {
24 model: ModelSelected,
25 no_kv_cache: bool,
26 chat_template: Option<String>,
27 jinja_explicit: Option<String>,
28}
29
30impl LoaderBuilder {
31 pub fn new(model: ModelSelected) -> Self {
32 Self {
33 model,
34 no_kv_cache: false,
35 chat_template: None,
36 jinja_explicit: None,
37 }
38 }
39
40 pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
41 self.no_kv_cache = no_kv_cache;
42 self
43 }
44 pub fn with_chat_template(mut self, chat_template: Option<String>) -> Self {
45 self.chat_template = chat_template;
46 self
47 }
48 pub fn with_jinja_explicit(mut self, jinja_explicit: Option<String>) -> Self {
49 self.jinja_explicit = jinja_explicit;
50 self
51 }
52
53 pub fn build(self) -> anyhow::Result<Box<dyn Loader>> {
54 loader_from_model_selected(self)
55 }
56}
57
58pub fn get_tgt_non_granular_index(model: &ModelSelected) -> Option<usize> {
59 match model {
60 ModelSelected::Plain { .. }
61 | ModelSelected::Run { .. }
62 | ModelSelected::Lora { .. }
63 | ModelSelected::GGUF { .. }
64 | ModelSelected::LoraGGUF { .. }
65 | ModelSelected::GGML { .. }
66 | ModelSelected::LoraGGML { .. }
67 | ModelSelected::Toml { .. }
68 | ModelSelected::MultimodalPlain { .. }
69 | ModelSelected::DiffusionPlain { .. }
70 | ModelSelected::Speech { .. }
71 | ModelSelected::Embedding { .. } => None,
72 ModelSelected::XLora {
73 tgt_non_granular_index,
74 ..
75 }
76 | ModelSelected::XLoraGGUF {
77 tgt_non_granular_index,
78 ..
79 }
80 | ModelSelected::XLoraGGML {
81 tgt_non_granular_index,
82 ..
83 } => *tgt_non_granular_index,
84 ModelSelected::MultiModel { .. } => {
85 panic!("MultiModel variant should not be used in model loading functions")
86 }
87 }
88}
89
90pub fn get_model_dtype(model: &ModelSelected) -> anyhow::Result<ModelDType> {
91 match model {
92 ModelSelected::Plain { dtype, .. }
93 | ModelSelected::Lora { dtype, .. }
94 | ModelSelected::XLora { dtype, .. }
95 | ModelSelected::MultimodalPlain { dtype, .. }
96 | ModelSelected::DiffusionPlain { dtype, .. }
97 | ModelSelected::GGML { dtype, .. }
98 | ModelSelected::GGUF { dtype, .. }
99 | ModelSelected::XLoraGGUF { dtype, .. }
100 | ModelSelected::XLoraGGML { dtype, .. }
101 | ModelSelected::LoraGGUF { dtype, .. }
102 | ModelSelected::LoraGGML { dtype, .. }
103 | ModelSelected::Run { dtype, .. }
104 | ModelSelected::Speech { dtype, .. }
105 | ModelSelected::Embedding { dtype, .. } => Ok(*dtype),
106 ModelSelected::Toml { file } => {
107 let selector: TomlSelector = toml::from_str(
108 &fs::read_to_string(file.clone())
109 .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
110 )?;
111 Ok(get_toml_selected_model_dtype(&selector))
112 }
113 ModelSelected::MultiModel { .. } => {
114 anyhow::bail!("MultiModel variant should not be used in model loading functions")
115 }
116 }
117}
118
119pub fn get_auto_device_map_params(model: &ModelSelected) -> anyhow::Result<AutoDeviceMapParams> {
120 match model {
121 ModelSelected::Plain {
122 max_seq_len,
123 max_batch_size,
124 ..
125 }
126 | ModelSelected::Lora {
127 max_seq_len,
128 max_batch_size,
129 ..
130 }
131 | ModelSelected::XLora {
132 max_seq_len,
133 max_batch_size,
134 ..
135 }
136 | ModelSelected::GGML {
137 max_seq_len,
138 max_batch_size,
139 ..
140 }
141 | ModelSelected::GGUF {
142 max_seq_len,
143 max_batch_size,
144 ..
145 }
146 | ModelSelected::XLoraGGUF {
147 max_seq_len,
148 max_batch_size,
149 ..
150 }
151 | ModelSelected::XLoraGGML {
152 max_seq_len,
153 max_batch_size,
154 ..
155 }
156 | ModelSelected::LoraGGUF {
157 max_seq_len,
158 max_batch_size,
159 ..
160 }
161 | ModelSelected::LoraGGML {
162 max_seq_len,
163 max_batch_size,
164 ..
165 } => Ok(AutoDeviceMapParams::Text {
166 max_seq_len: *max_seq_len,
167 max_batch_size: *max_batch_size,
168 }),
169 ModelSelected::Run {
170 max_seq_len,
171 max_batch_size,
172 max_image_length,
173 max_num_images,
174 ..
175 } => {
176 if max_num_images.is_some() || max_image_length.is_some() {
177 let max_image_length =
178 max_image_length.unwrap_or(AutoDeviceMapParams::DEFAULT_MAX_IMAGE_LENGTH);
179 Ok(AutoDeviceMapParams::Multimodal {
180 max_seq_len: *max_seq_len,
181 max_batch_size: *max_batch_size,
182 max_image_shape: (max_image_length, max_image_length),
183 max_num_images: max_num_images
184 .unwrap_or(AutoDeviceMapParams::DEFAULT_MAX_NUM_IMAGES),
185 })
186 } else {
187 Ok(AutoDeviceMapParams::Text {
188 max_seq_len: *max_seq_len,
189 max_batch_size: *max_batch_size,
190 })
191 }
192 }
193 ModelSelected::MultimodalPlain {
194 max_seq_len,
195 max_batch_size,
196 max_image_length,
197 max_num_images,
198 ..
199 } => Ok(AutoDeviceMapParams::Multimodal {
200 max_seq_len: *max_seq_len,
201 max_batch_size: *max_batch_size,
202 max_image_shape: (*max_image_length, *max_image_length),
203 max_num_images: *max_num_images,
204 }),
205 ModelSelected::DiffusionPlain { .. }
206 | ModelSelected::Speech { .. }
207 | ModelSelected::Embedding { .. } => Ok(AutoDeviceMapParams::default_text()),
208 ModelSelected::Toml { file } => {
209 let selector: TomlSelector = toml::from_str(
210 &fs::read_to_string(file.clone())
211 .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
212 )?;
213 get_toml_selected_model_device_map_params(&selector)
214 }
215 ModelSelected::MultiModel { .. } => {
216 anyhow::bail!("MultiModel variant should not be used in model loading functions")
217 }
218 }
219}
220
221fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loader>> {
222 let loader: Box<dyn Loader> = match args.model {
223 ModelSelected::Toml { file } => {
224 let selector: TomlSelector = toml::from_str(
225 &fs::read_to_string(file.clone())
226 .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")),
227 )?;
228 let args = TomlLoaderArgs {
229 chat_template: args.chat_template,
230 no_kv_cache: args.no_kv_cache,
231 jinja_explicit: args.jinja_explicit,
232 };
233 (selector, args).try_into()?
234 }
235 ModelSelected::Plain {
236 model_id,
237 tokenizer_json,
238 arch,
239 dtype: _,
240 topology,
241 organization,
242 write_uqff,
243 from_uqff,
244 imatrix,
245 calibration_file,
246 max_seq_len: _,
247 max_batch_size: _,
248 hf_cache_path,
249 matformer_config_path,
250 matformer_slice_name,
251 } => NormalLoaderBuilder::new(
252 NormalSpecificConfig {
253 topology: Topology::from_option_path(topology)?,
254 organization: organization.unwrap_or_default(),
255 write_uqff,
256 from_uqff: from_uqff.map(|x| {
257 x.split(UQFF_MULTI_FILE_DELIMITER)
258 .map(PathBuf::from_str)
259 .map(|x| x.unwrap())
260 .collect::<Vec<_>>()
261 }),
262 imatrix,
263 calibration_file,
264 hf_cache_path,
265 matformer_config_path,
266 matformer_slice_name,
267 },
268 args.chat_template,
269 tokenizer_json,
270 Some(model_id),
271 args.no_kv_cache,
272 args.jinja_explicit,
273 )
274 .build(arch)?,
275 ModelSelected::Run {
276 model_id,
277 tokenizer_json,
278 dtype: _,
279 topology,
280 organization,
281 write_uqff,
282 from_uqff,
283 imatrix,
284 calibration_file,
285 max_edge,
286 max_seq_len: _,
287 max_batch_size: _,
288 max_num_images: _,
289 max_image_length: _,
290 hf_cache_path,
291 matformer_config_path,
292 matformer_slice_name,
293 } => {
294 let builder = AutoLoaderBuilder::new(
295 NormalSpecificConfig {
296 topology: Topology::from_option_path(topology.clone())?,
297 organization: organization.unwrap_or_default(),
298 write_uqff: write_uqff.clone(),
299 from_uqff: from_uqff.clone().map(|x| {
300 x.split(UQFF_MULTI_FILE_DELIMITER)
301 .map(PathBuf::from_str)
302 .map(|x| x.unwrap())
303 .collect::<Vec<_>>()
304 }),
305 imatrix: imatrix.clone(),
306 calibration_file: calibration_file.clone(),
307 hf_cache_path: hf_cache_path.clone(),
308 matformer_config_path: matformer_config_path.clone(),
309 matformer_slice_name: matformer_slice_name.clone(),
310 },
311 MultimodalSpecificConfig {
312 topology: Topology::from_option_path(topology.clone())?,
313 write_uqff: write_uqff.clone(),
314 from_uqff: from_uqff.clone().map(|x| {
315 x.split(UQFF_MULTI_FILE_DELIMITER)
316 .map(PathBuf::from_str)
317 .map(|x| x.unwrap())
318 .collect::<Vec<_>>()
319 }),
320 max_edge,
321 calibration_file,
322 imatrix,
323 hf_cache_path: hf_cache_path.clone(),
324 matformer_config_path,
325 matformer_slice_name,
326 organization: organization.unwrap_or_default(),
327 },
328 EmbeddingSpecificConfig {
329 topology: Topology::from_option_path(topology)?,
330 write_uqff,
331 from_uqff: from_uqff.map(|x| {
332 x.split(UQFF_MULTI_FILE_DELIMITER)
333 .map(PathBuf::from_str)
334 .map(|x| x.unwrap())
335 .collect::<Vec<_>>()
336 }),
337 hf_cache_path: hf_cache_path.clone(),
338 },
339 args.chat_template,
340 tokenizer_json,
341 model_id,
342 args.no_kv_cache,
343 args.jinja_explicit,
344 );
345 let builder = if let Some(ref path) = hf_cache_path {
346 builder.hf_cache_path(path.clone())
347 } else {
348 builder
349 };
350 builder.build()
351 }
352 ModelSelected::MultimodalPlain {
353 model_id,
354 tokenizer_json,
355 arch,
356 dtype: _,
357 topology,
358 write_uqff,
359 from_uqff,
360 max_edge,
361 calibration_file,
362 max_seq_len: _,
363 max_batch_size: _,
364 max_num_images: _,
365 max_image_length: _,
366 hf_cache_path,
367 imatrix,
368 matformer_config_path,
369 matformer_slice_name,
370 organization,
371 } => MultimodalLoaderBuilder::new(
372 MultimodalSpecificConfig {
373 topology: Topology::from_option_path(topology)?,
374 write_uqff,
375 from_uqff: from_uqff.map(|x| {
376 x.split(UQFF_MULTI_FILE_DELIMITER)
377 .map(PathBuf::from_str)
378 .map(|x| x.unwrap())
379 .collect::<Vec<_>>()
380 }),
381 max_edge,
382 calibration_file,
383 imatrix,
384 hf_cache_path,
385 matformer_config_path,
386 matformer_slice_name,
387 organization: organization.unwrap_or_default(),
388 },
389 args.chat_template,
390 tokenizer_json,
391 Some(model_id),
392 args.jinja_explicit,
393 )
394 .build(arch),
395 ModelSelected::DiffusionPlain {
396 model_id,
397 arch,
398 dtype: _,
399 } => DiffusionLoaderBuilder::new(Some(model_id)).build(arch),
400 ModelSelected::Speech {
401 model_id,
402 dac_model_id,
403 arch,
404 ..
405 } => Box::new(SpeechLoader {
406 model_id,
407 dac_model_id,
408 arch,
409 cfg: None,
410 }),
411 ModelSelected::XLora {
412 model_id,
413 xlora_model_id,
414 order,
415 tokenizer_json,
416 tgt_non_granular_index,
417 arch,
418 dtype: _,
419 topology,
420 write_uqff,
421 from_uqff,
422 max_seq_len: _,
423 max_batch_size: _,
424 hf_cache_path,
425 } => NormalLoaderBuilder::new(
426 NormalSpecificConfig {
427 topology: Topology::from_option_path(topology)?,
428 organization: Default::default(),
429 write_uqff,
430 from_uqff: from_uqff.map(|x| {
431 x.split(UQFF_MULTI_FILE_DELIMITER)
432 .map(PathBuf::from_str)
433 .map(|x| x.unwrap())
434 .collect::<Vec<_>>()
435 }),
436 imatrix: None,
437 calibration_file: None,
438 hf_cache_path,
439 matformer_config_path: None,
440 matformer_slice_name: None,
441 },
442 args.chat_template,
443 tokenizer_json,
444 model_id,
445 args.no_kv_cache,
446 args.jinja_explicit,
447 )
448 .with_xlora(
449 xlora_model_id,
450 serde_json::from_reader(
451 File::open(order.clone())
452 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
453 )?,
454 args.no_kv_cache,
455 tgt_non_granular_index,
456 )
457 .build(arch)?,
458 ModelSelected::Lora {
459 model_id,
460 tokenizer_json,
461 adapter_model_id,
462 arch,
463 dtype: _,
464 topology,
465 write_uqff,
466 from_uqff,
467 max_seq_len: _,
468 max_batch_size: _,
469 hf_cache_path,
470 } => NormalLoaderBuilder::new(
471 NormalSpecificConfig {
472 topology: Topology::from_option_path(topology)?,
473 organization: Default::default(),
474 write_uqff,
475 from_uqff: from_uqff.map(|x| {
476 x.split(UQFF_MULTI_FILE_DELIMITER)
477 .map(PathBuf::from_str)
478 .map(|x| x.unwrap())
479 .collect::<Vec<_>>()
480 }),
481 imatrix: None,
482 calibration_file: None,
483 hf_cache_path,
484 matformer_config_path: None,
485 matformer_slice_name: None,
486 },
487 args.chat_template,
488 tokenizer_json,
489 model_id,
490 args.no_kv_cache,
491 args.jinja_explicit,
492 )
493 .with_lora(
494 adapter_model_id
495 .split(MULTI_LORA_DELIMITER)
496 .map(ToString::to_string)
497 .collect(),
498 )
499 .build(arch)?,
500 ModelSelected::GGUF {
501 tok_model_id,
502 quantized_model_id,
503 quantized_filename,
504 topology,
505 ..
506 } => GGUFLoaderBuilder::new(
507 args.chat_template,
508 tok_model_id,
509 quantized_model_id,
510 quantized_filename
511 .split(GGUF_MULTI_FILE_DELIMITER)
512 .map(ToOwned::to_owned)
513 .collect::<Vec<_>>(),
514 GGUFSpecificConfig {
515 topology: Topology::from_option_path(topology)?,
516 },
517 args.no_kv_cache,
518 args.jinja_explicit,
519 )
520 .build(),
521 ModelSelected::XLoraGGUF {
522 tok_model_id,
523 quantized_model_id,
524 quantized_filename,
525 xlora_model_id,
526 order,
527 tgt_non_granular_index,
528 topology,
529 ..
530 } => GGUFLoaderBuilder::new(
531 args.chat_template,
532 tok_model_id,
533 quantized_model_id,
534 quantized_filename
535 .split(GGUF_MULTI_FILE_DELIMITER)
536 .map(ToOwned::to_owned)
537 .collect::<Vec<_>>(),
538 GGUFSpecificConfig {
539 topology: Topology::from_option_path(topology)?,
540 },
541 args.no_kv_cache,
542 args.jinja_explicit,
543 )
544 .with_xlora(
545 xlora_model_id,
546 serde_json::from_reader(
547 File::open(order.clone())
548 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
549 )?,
550 args.no_kv_cache,
551 tgt_non_granular_index,
552 )
553 .build(),
554 ModelSelected::LoraGGUF {
555 tok_model_id,
556 quantized_model_id,
557 quantized_filename,
558 adapters_model_id,
559 order,
560 topology,
561 ..
562 } => GGUFLoaderBuilder::new(
563 args.chat_template,
564 tok_model_id,
565 quantized_model_id,
566 quantized_filename
567 .split(GGUF_MULTI_FILE_DELIMITER)
568 .map(ToOwned::to_owned)
569 .collect::<Vec<_>>(),
570 GGUFSpecificConfig {
571 topology: Topology::from_option_path(topology)?,
572 },
573 args.no_kv_cache,
574 args.jinja_explicit,
575 )
576 .with_lora(
577 adapters_model_id,
578 serde_json::from_reader(
579 File::open(order.clone())
580 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
581 )?,
582 )
583 .build(),
584 ModelSelected::GGML {
585 tok_model_id,
586 tokenizer_json,
587 quantized_model_id,
588 quantized_filename,
589 gqa,
590 topology,
591 ..
592 } => GGMLLoaderBuilder::new(
593 GGMLSpecificConfig {
594 gqa,
595 topology: Topology::from_option_path(topology)?,
596 },
597 args.chat_template,
598 tokenizer_json,
599 Some(tok_model_id),
600 quantized_model_id,
601 quantized_filename,
602 args.no_kv_cache,
603 args.jinja_explicit,
604 )
605 .build(),
606 ModelSelected::XLoraGGML {
607 tok_model_id,
608 tokenizer_json,
609 quantized_model_id,
610 quantized_filename,
611 xlora_model_id,
612 order,
613 tgt_non_granular_index,
614 gqa,
615 topology,
616 ..
617 } => GGMLLoaderBuilder::new(
618 GGMLSpecificConfig {
619 gqa,
620 topology: Topology::from_option_path(topology)?,
621 },
622 args.chat_template,
623 tokenizer_json,
624 tok_model_id,
625 quantized_model_id,
626 quantized_filename,
627 args.no_kv_cache,
628 args.jinja_explicit,
629 )
630 .with_xlora(
631 xlora_model_id,
632 serde_json::from_reader(
633 File::open(order.clone())
634 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
635 )?,
636 args.no_kv_cache,
637 tgt_non_granular_index,
638 )
639 .build(),
640 ModelSelected::LoraGGML {
641 tok_model_id,
642 tokenizer_json,
643 quantized_model_id,
644 quantized_filename,
645 adapters_model_id,
646 order,
647 gqa,
648 topology,
649 ..
650 } => GGMLLoaderBuilder::new(
651 GGMLSpecificConfig {
652 gqa,
653 topology: Topology::from_option_path(topology)?,
654 },
655 args.chat_template,
656 tokenizer_json,
657 tok_model_id,
658 quantized_model_id,
659 quantized_filename,
660 args.no_kv_cache,
661 args.jinja_explicit,
662 )
663 .with_lora(
664 adapters_model_id,
665 serde_json::from_reader(
666 File::open(order.clone())
667 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
668 )?,
669 )
670 .build(),
671 ModelSelected::Embedding {
672 model_id,
673 tokenizer_json,
674 arch,
675 dtype: _,
676 topology,
677 write_uqff,
678 from_uqff,
679 hf_cache_path,
680 } => EmbeddingLoaderBuilder::new(
681 EmbeddingSpecificConfig {
682 topology: Topology::from_option_path(topology)?,
683 write_uqff,
684 from_uqff: from_uqff.map(|x| {
685 x.split(UQFF_MULTI_FILE_DELIMITER)
686 .map(PathBuf::from_str)
687 .map(|x| x.unwrap())
688 .collect::<Vec<_>>()
689 }),
690 hf_cache_path,
691 },
692 tokenizer_json,
693 Some(model_id),
694 )
695 .build(arch),
696 ModelSelected::MultiModel { .. } => {
697 anyhow::bail!("MultiModel variant should not be used in model loading functions")
698 }
699 };
700 Ok(loader)
701}