1use std::{fs::File, path::PathBuf, str::FromStr};
2
3use hanzo_quant::MULTI_LORA_DELIMITER;
4use serde::Deserialize;
5
6use crate::{
7 amoe::AnyMoeConfig,
8 pipeline::{EmbeddingLoaderType, IsqOrganization},
9 AnyMoeLoader, AutoDeviceMapParams, EmbeddingLoaderBuilder, EmbeddingSpecificConfig,
10 GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoaderBuilder, GGUFSpecificConfig, Loader,
11 ModelDType, MultimodalLoaderBuilder, MultimodalLoaderType, MultimodalSpecificConfig,
12 NormalLoaderBuilder, NormalLoaderType, NormalSpecificConfig, Topology,
13 GGUF_MULTI_FILE_DELIMITER, UQFF_MULTI_FILE_DELIMITER,
14};
15
/// Serde default for the GGML-family `gqa` fields: the value `1`.
fn default_one() -> usize {
    1usize
}
19
20fn default_dtype() -> ModelDType {
21 ModelDType::Auto
22}
23
/// Serde default for list-of-index fields (e.g. AnyMoE `layers`): an empty vector.
fn default_empty_vec_usize() -> Vec<usize> {
    Vec::default()
}
27
28fn default_max_seq_len() -> usize {
29 AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN
30}
31
32fn default_max_batch_size() -> usize {
33 AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE
34}
35
36fn default_max_num_images() -> usize {
37 AutoDeviceMapParams::DEFAULT_MAX_NUM_IMAGES
38}
39
40fn default_max_image_length() -> usize {
41 AutoDeviceMapParams::DEFAULT_MAX_IMAGE_LENGTH
42}
43
/// The `[model]` table of a TOML selector: which kind of model to load and how.
///
/// Deserialized with `#[serde(untagged)]`, so there is no discriminator key: serde tries the
/// variants in declaration order and keeps the first one that deserializes successfully.
///
/// NOTE(review): no variant uses `#[serde(deny_unknown_fields)]`, and serde ignores unknown
/// fields by default, so variant order is load-bearing. Two consequences to confirm:
/// - `Plain` only requires `model_id` (everything else is `Option` or defaulted). A table
///   meant for a later variant that also sets `model_id` (e.g. `XLora`, `Lora`,
///   `MultimodalPlain`, `Embedding`) may be taken as `Plain`, unless one of its values fails
///   to parse as a `Plain` field (such as an `arch` that is not a `NormalLoaderType`).
/// - `GGUF` and `GGML` have the same required fields (`tok_model_id`, `quantized_model_id`,
///   `quantized_filename`). A GGML table therefore looks like it will always match `GGUF`
///   first. The same applies to `XLoraGGUF`/`LoraGGUF` tables that set `tok_model_id`.
/// Do not reorder or add variants without tests covering each selector shape.
///
/// Fields shared by several variants:
/// - `dtype`: defaults to `ModelDType::Auto`. It is read by `get_toml_selected_model_dtype`
///   and ignored by `loader_from_selected`.
/// - `topology`: optional path handed to `Topology::from_option_path`.
/// - `max_seq_len` / `max_batch_size`: default to the `AutoDeviceMapParams::DEFAULT_*`
///   constants. They are only consumed by `get_toml_selected_model_device_map_params`.
/// - `from_uqff`: a single string holding one or more paths separated by
///   `UQFF_MULTI_FILE_DELIMITER`.
/// - `order`: path to a JSON ordering file that is opened and parsed when the loader is built.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum TomlModelSelected {
    /// Model built with `NormalLoaderBuilder` from `model_id`, without adapters.
    Plain {
        /// Model id passed as `Some(model_id)` to `NormalLoaderBuilder::new`.
        model_id: String,

        /// Architecture passed to `NormalLoaderBuilder::build`.
        arch: Option<NormalLoaderType>,

        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        topology: Option<String>,

        /// ISQ organization. `IsqOrganization::default()` is used when omitted.
        organization: Option<IsqOrganization>,

        write_uqff: Option<PathBuf>,

        from_uqff: Option<String>,

        /// Forwarded unchanged to `NormalSpecificConfig::imatrix`.
        imatrix: Option<PathBuf>,

        /// Forwarded unchanged to `NormalSpecificConfig::calibration_file`.
        calibration_file: Option<PathBuf>,

        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,

        hf_cache_path: Option<PathBuf>,
    },

    /// `NormalLoaderBuilder` model with X-LoRA adapters attached via `with_xlora`.
    XLora {
        /// Optional base model id passed to `NormalLoaderBuilder::new`.
        model_id: Option<String>,

        /// X-LoRA adapter model id passed to `with_xlora`.
        xlora_model_id: String,

        order: String,

        /// Forwarded unchanged to `with_xlora`.
        tgt_non_granular_index: Option<usize>,

        arch: Option<NormalLoaderType>,

        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        topology: Option<String>,

        write_uqff: Option<PathBuf>,

        from_uqff: Option<String>,

        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,

        hf_cache_path: Option<PathBuf>,
    },

    /// `NormalLoaderBuilder` model with LoRA adapters attached via `with_lora`.
    Lora {
        /// Optional base model id passed to `NormalLoaderBuilder::new`.
        model_id: Option<String>,

        /// One or more adapter ids separated by `MULTI_LORA_DELIMITER`.
        adapter_model_ids: String,

        arch: Option<NormalLoaderType>,

        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        topology: Option<String>,

        write_uqff: Option<PathBuf>,

        from_uqff: Option<String>,

        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,

        hf_cache_path: Option<PathBuf>,
    },

    /// Quantized model built with `GGUFLoaderBuilder`.
    // The variant name mirrors the GGUF file-format acronym, hence the clippy allow.
    #[allow(clippy::upper_case_acronyms)]
    GGUF {
        /// Tokenizer model id, passed as `Some(tok_model_id)`.
        tok_model_id: String,

        quantized_model_id: String,

        /// One or more file names separated by `GGUF_MULTI_FILE_DELIMITER`.
        quantized_filename: String,

        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        topology: Option<String>,

        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// `GGUFLoaderBuilder` model with X-LoRA adapters attached via `with_xlora`.
    XLoraGGUF {
        tok_model_id: Option<String>,

        quantized_model_id: String,

        /// One or more file names separated by `GGUF_MULTI_FILE_DELIMITER`.
        quantized_filename: String,

        xlora_model_id: String,

        order: String,

        tgt_non_granular_index: Option<usize>,

        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        topology: Option<String>,

        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// `GGUFLoaderBuilder` model with LoRA adapters attached via `with_lora`.
    LoraGGUF {
        tok_model_id: Option<String>,

        quantized_model_id: String,

        /// One or more file names separated by `GGUF_MULTI_FILE_DELIMITER`.
        quantized_filename: String,

        adapters_model_id: String,

        order: String,

        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        topology: Option<String>,

        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// Quantized model built with `GGMLLoaderBuilder`.
    // The variant name mirrors the GGML file-format acronym, hence the clippy allow.
    #[allow(clippy::upper_case_acronyms)]
    GGML {
        /// Tokenizer model id, passed as `Some(tok_model_id)`.
        tok_model_id: String,

        quantized_model_id: String,

        /// Passed to the builder as a single file name (not split on any delimiter).
        quantized_filename: String,

        /// Forwarded to `GGMLSpecificConfig::gqa`. Defaults to 1.
        #[serde(default = "default_one")]
        gqa: usize,

        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        topology: Option<String>,

        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// `GGMLLoaderBuilder` model with X-LoRA adapters attached via `with_xlora`.
    XLoraGGML {
        tok_model_id: Option<String>,

        quantized_model_id: String,

        quantized_filename: String,

        xlora_model_id: String,

        order: String,

        tgt_non_granular_index: Option<usize>,

        /// Forwarded to `GGMLSpecificConfig::gqa`. Defaults to 1.
        #[serde(default = "default_one")]
        gqa: usize,

        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        topology: Option<String>,

        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// `GGMLLoaderBuilder` model with LoRA adapters attached via `with_lora`.
    LoraGGML {
        tok_model_id: Option<String>,

        quantized_model_id: String,

        quantized_filename: String,

        adapters_model_id: String,

        order: String,

        /// Forwarded to `GGMLSpecificConfig::gqa`. Defaults to 1.
        #[serde(default = "default_one")]
        gqa: usize,

        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        topology: Option<String>,

        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,
    },

    /// Model built with `MultimodalLoaderBuilder` from `model_id`.
    MultimodalPlain {
        model_id: String,

        /// Architecture passed to `MultimodalLoaderBuilder::build`.
        arch: Option<MultimodalLoaderType>,

        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        topology: Option<String>,

        write_uqff: Option<PathBuf>,

        from_uqff: Option<String>,

        /// Forwarded unchanged to `MultimodalSpecificConfig::max_edge`.
        max_edge: Option<u32>,

        calibration_file: Option<PathBuf>,

        imatrix: Option<PathBuf>,

        #[serde(default = "default_max_seq_len")]
        max_seq_len: usize,

        #[serde(default = "default_max_batch_size")]
        max_batch_size: usize,

        /// Only used for device-map sizing (`AutoDeviceMapParams::Multimodal`).
        #[serde(default = "default_max_num_images")]
        max_num_images: usize,

        /// Only used for device-map sizing, as both dimensions of `max_image_shape`.
        #[serde(default = "default_max_image_length")]
        max_image_length: usize,

        hf_cache_path: Option<PathBuf>,

        /// ISQ organization. `IsqOrganization::default()` is used when omitted.
        organization: Option<IsqOrganization>,
    },

    /// Embedding model built with `EmbeddingLoaderBuilder`.
    ///
    /// The shared chat template, KV-cache, Jinja and top-level `tokenizer_json` options are not
    /// passed to this builder. Only the variant's own `tokenizer_json` is used.
    Embedding {
        model_id: String,

        #[serde(default)]
        tokenizer_json: Option<String>,

        #[serde(default)]
        arch: Option<EmbeddingLoaderType>,

        #[serde(default = "default_dtype")]
        dtype: ModelDType,

        #[serde(default)]
        topology: Option<String>,

        #[serde(default)]
        write_uqff: Option<PathBuf>,

        #[serde(default)]
        from_uqff: Option<String>,

        #[serde(default)]
        hf_cache_path: Option<PathBuf>,
    },
}
482
/// Optional `[anymoe]` table. When present, the selected model's loader is wrapped in an
/// `AnyMoeLoader` whose fields are filled one-to-one from this table.
#[derive(Deserialize)]
pub struct AnyMoeTomlModelSelected {
    /// Forwarded as `AnyMoeLoader::config`.
    config: AnyMoeConfig,

    /// Forwarded as `AnyMoeLoader::path`.
    dataset_json: String,

    /// Forwarded as `AnyMoeLoader::prefix`.
    prefix: String,

    /// Forwarded as `AnyMoeLoader::mlp`.
    mlp: String,

    /// Forwarded as `AnyMoeLoader::model_ids`.
    model_ids: Vec<String>,

    /// Forwarded as `AnyMoeLoader::layers`. Defaults to an empty list when omitted.
    #[serde(default = "default_empty_vec_usize")]
    layers: Vec<usize>,
}
504
/// Root of a TOML model-selector file.
///
/// Converted into a loader via the `TryInto<Box<dyn Loader>>` impl for
/// `(TomlSelector, TomlLoaderArgs)`.
#[derive(Deserialize)]
pub struct TomlSelector {
    /// Top-level tokenizer JSON override.
    ///
    /// `loader_from_selected` forwards it to the Plain, XLora, Lora, GGML-family and
    /// MultimodalPlain builders. The GGUF-family branches do not pass it, and `Embedding`
    /// uses its own `tokenizer_json` field instead.
    tokenizer_json: Option<String>,

    /// The `[model]` table selecting what to load.
    model: TomlModelSelected,

    /// Legacy target/draft speculative-decoding section.
    ///
    /// Its contents are discarded (`IgnoredAny`). The field exists only so that its presence
    /// can be detected and rejected with an explicit error when the loader is built.
    #[serde(default)]
    speculative: Option<serde::de::IgnoredAny>,

    /// Optional `[anymoe]` table. When present, the model loader is wrapped in `AnyMoeLoader`.
    anymoe: Option<AnyMoeTomlModelSelected>,
}
521
/// Options handed to `loader_from_selected`.
///
/// Assembled from the caller's [`TomlLoaderArgs`] plus the selector's top-level
/// `tokenizer_json`.
#[derive(Clone)]
struct TomlLoaderInnerParams {
    /// Top-level tokenizer JSON override taken from the TOML selector.
    tokenizer_json: Option<String>,
    /// Chat template forwarded to the loader builders.
    chat_template: Option<String>,
    /// Forwarded to the builders' `jinja_explicit` argument.
    jinja_explicit: Option<String>,
    /// Forwarded to the builders' `no_kv_cache` argument.
    no_kv_cache: bool,
}
529
/// Caller-supplied options that complement a parsed TOML selector when building a loader.
///
/// Consumed by the `TryInto<Box<dyn Loader>>` impl for `(TomlSelector, TomlLoaderArgs)`.
pub struct TomlLoaderArgs {
    /// Forwarded to the loader builders' `no_kv_cache` argument.
    pub no_kv_cache: bool,
    /// Chat template forwarded to the loader builders, if any.
    pub chat_template: Option<String>,
    /// Forwarded to the loader builders' `jinja_explicit` argument.
    pub jinja_explicit: Option<String>,
}
535
536pub fn get_toml_selected_model_dtype(model: &TomlSelector) -> ModelDType {
537 match model.model {
538 TomlModelSelected::Plain { dtype, .. }
539 | TomlModelSelected::Lora { dtype, .. }
540 | TomlModelSelected::XLora { dtype, .. }
541 | TomlModelSelected::MultimodalPlain { dtype, .. }
542 | TomlModelSelected::GGUF { dtype, .. }
543 | TomlModelSelected::GGML { dtype, .. }
544 | TomlModelSelected::XLoraGGUF { dtype, .. }
545 | TomlModelSelected::XLoraGGML { dtype, .. }
546 | TomlModelSelected::LoraGGUF { dtype, .. }
547 | TomlModelSelected::LoraGGML { dtype, .. }
548 | TomlModelSelected::Embedding { dtype, .. } => dtype,
549 }
550}
551
552pub fn get_toml_selected_model_device_map_params(
553 model: &TomlSelector,
554) -> anyhow::Result<AutoDeviceMapParams> {
555 match model.model {
556 TomlModelSelected::Plain {
557 max_seq_len,
558 max_batch_size,
559 ..
560 }
561 | TomlModelSelected::Lora {
562 max_seq_len,
563 max_batch_size,
564 ..
565 }
566 | TomlModelSelected::XLora {
567 max_seq_len,
568 max_batch_size,
569 ..
570 }
571 | TomlModelSelected::GGML {
572 max_seq_len,
573 max_batch_size,
574 ..
575 }
576 | TomlModelSelected::GGUF {
577 max_seq_len,
578 max_batch_size,
579 ..
580 }
581 | TomlModelSelected::XLoraGGUF {
582 max_seq_len,
583 max_batch_size,
584 ..
585 }
586 | TomlModelSelected::XLoraGGML {
587 max_seq_len,
588 max_batch_size,
589 ..
590 }
591 | TomlModelSelected::LoraGGUF {
592 max_seq_len,
593 max_batch_size,
594 ..
595 }
596 | TomlModelSelected::LoraGGML {
597 max_seq_len,
598 max_batch_size,
599 ..
600 } => Ok(AutoDeviceMapParams::Text {
601 max_seq_len,
602 max_batch_size,
603 }),
604 TomlModelSelected::Embedding { .. } => Ok(AutoDeviceMapParams::default_text()),
605 TomlModelSelected::MultimodalPlain {
606 max_seq_len,
607 max_batch_size,
608 max_image_length,
609 max_num_images,
610 ..
611 } => Ok(AutoDeviceMapParams::Multimodal {
612 max_seq_len,
613 max_batch_size,
614 max_image_shape: (max_image_length, max_image_length),
615 max_num_images,
616 }),
617 }
618}
619
620fn loader_from_selected(
621 args: TomlLoaderInnerParams,
622 model: TomlModelSelected,
623) -> anyhow::Result<Box<dyn Loader>> {
624 let loader: Box<dyn Loader> = match model {
625 TomlModelSelected::Plain {
626 model_id,
627 arch,
628 dtype: _,
629 topology,
630 organization,
631 write_uqff,
632 from_uqff,
633 imatrix,
634 calibration_file,
635 max_seq_len: _,
636 max_batch_size: _,
637 hf_cache_path,
638 } => NormalLoaderBuilder::new(
639 NormalSpecificConfig {
640 topology: Topology::from_option_path(topology)?,
641 organization: organization.unwrap_or_default(),
642 write_uqff,
643 from_uqff: from_uqff.map(|x| {
644 x.split(UQFF_MULTI_FILE_DELIMITER)
645 .map(PathBuf::from_str)
646 .map(|x| x.unwrap())
647 .collect::<Vec<_>>()
648 }),
649 imatrix,
650 calibration_file,
651 hf_cache_path,
652 matformer_config_path: None,
653 matformer_slice_name: None,
654 },
655 args.chat_template,
656 args.tokenizer_json,
657 Some(model_id),
658 args.no_kv_cache,
659 args.jinja_explicit,
660 )
661 .build(arch)?,
662 TomlModelSelected::XLora {
663 model_id,
664 xlora_model_id,
665 order,
666 tgt_non_granular_index,
667 arch,
668 dtype: _,
669 topology,
670 write_uqff,
671 from_uqff,
672 max_seq_len: _,
673 max_batch_size: _,
674 hf_cache_path,
675 } => NormalLoaderBuilder::new(
676 NormalSpecificConfig {
677 topology: Topology::from_option_path(topology)?,
678 organization: Default::default(),
679 write_uqff,
680 from_uqff: from_uqff.map(|x| {
681 x.split(UQFF_MULTI_FILE_DELIMITER)
682 .map(PathBuf::from_str)
683 .map(|x| x.unwrap())
684 .collect::<Vec<_>>()
685 }),
686 imatrix: None,
687 calibration_file: None,
688 hf_cache_path,
689 matformer_config_path: None,
690 matformer_slice_name: None,
691 },
692 args.chat_template,
693 args.tokenizer_json,
694 model_id,
695 args.no_kv_cache,
696 args.jinja_explicit,
697 )
698 .with_xlora(
699 xlora_model_id,
700 serde_json::from_reader(
701 File::open(order.clone())
702 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
703 )?,
704 args.no_kv_cache,
705 tgt_non_granular_index,
706 )
707 .build(arch)?,
708 TomlModelSelected::Lora {
709 model_id,
710 adapter_model_ids,
711 arch,
712 dtype: _,
713 topology,
714 write_uqff,
715 from_uqff,
716 max_seq_len: _,
717 max_batch_size: _,
718 hf_cache_path,
719 } => NormalLoaderBuilder::new(
720 NormalSpecificConfig {
721 topology: Topology::from_option_path(topology)?,
722 organization: Default::default(),
723 write_uqff,
724 from_uqff: from_uqff.map(|x| {
725 x.split(UQFF_MULTI_FILE_DELIMITER)
726 .map(PathBuf::from_str)
727 .map(|x| x.unwrap())
728 .collect::<Vec<_>>()
729 }),
730 imatrix: None,
731 calibration_file: None,
732 hf_cache_path,
733 matformer_config_path: None,
734 matformer_slice_name: None,
735 },
736 args.chat_template,
737 args.tokenizer_json,
738 model_id,
739 args.no_kv_cache,
740 args.jinja_explicit,
741 )
742 .with_lora(
743 adapter_model_ids
744 .split(MULTI_LORA_DELIMITER)
745 .map(ToString::to_string)
746 .collect(),
747 )
748 .build(arch)?,
749 TomlModelSelected::GGUF {
750 tok_model_id,
751 quantized_model_id,
752 quantized_filename,
753 topology,
754 dtype: _,
755 max_seq_len: _,
756 max_batch_size: _,
757 } => GGUFLoaderBuilder::new(
758 args.chat_template,
759 Some(tok_model_id),
760 quantized_model_id,
761 quantized_filename
762 .split(GGUF_MULTI_FILE_DELIMITER)
763 .map(ToOwned::to_owned)
764 .collect::<Vec<_>>(),
765 GGUFSpecificConfig {
766 topology: Topology::from_option_path(topology)?,
767 },
768 args.no_kv_cache,
769 args.jinja_explicit,
770 )
771 .build(),
772 TomlModelSelected::XLoraGGUF {
773 tok_model_id,
774 quantized_model_id,
775 quantized_filename,
776 xlora_model_id,
777 order,
778 tgt_non_granular_index,
779 topology,
780 dtype: _,
781 max_seq_len: _,
782 max_batch_size: _,
783 } => GGUFLoaderBuilder::new(
784 args.chat_template,
785 tok_model_id,
786 quantized_model_id,
787 quantized_filename
788 .split(GGUF_MULTI_FILE_DELIMITER)
789 .map(ToOwned::to_owned)
790 .collect::<Vec<_>>(),
791 GGUFSpecificConfig {
792 topology: Topology::from_option_path(topology)?,
793 },
794 args.no_kv_cache,
795 args.jinja_explicit,
796 )
797 .with_xlora(
798 xlora_model_id,
799 serde_json::from_reader(
800 File::open(order.clone())
801 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
802 )?,
803 args.no_kv_cache,
804 tgt_non_granular_index,
805 )
806 .build(),
807 TomlModelSelected::LoraGGUF {
808 tok_model_id,
809 quantized_model_id,
810 quantized_filename,
811 adapters_model_id,
812 order,
813 topology,
814 ..
815 } => GGUFLoaderBuilder::new(
816 args.chat_template,
817 tok_model_id,
818 quantized_model_id,
819 quantized_filename
820 .split(GGUF_MULTI_FILE_DELIMITER)
821 .map(ToOwned::to_owned)
822 .collect::<Vec<_>>(),
823 GGUFSpecificConfig {
824 topology: Topology::from_option_path(topology)?,
825 },
826 args.no_kv_cache,
827 args.jinja_explicit,
828 )
829 .with_lora(
830 adapters_model_id,
831 serde_json::from_reader(
832 File::open(order.clone())
833 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
834 )?,
835 )
836 .build(),
837 TomlModelSelected::GGML {
838 tok_model_id,
839 quantized_model_id,
840 quantized_filename,
841 gqa,
842 topology,
843 dtype: _,
844 max_seq_len: _,
845 max_batch_size: _,
846 } => GGMLLoaderBuilder::new(
847 GGMLSpecificConfig {
848 gqa,
849 topology: Topology::from_option_path(topology)?,
850 },
851 args.chat_template,
852 args.tokenizer_json,
853 Some(tok_model_id),
854 quantized_model_id,
855 quantized_filename,
856 args.no_kv_cache,
857 args.jinja_explicit,
858 )
859 .build(),
860 TomlModelSelected::XLoraGGML {
861 tok_model_id,
862 quantized_model_id,
863 quantized_filename,
864 xlora_model_id,
865 order,
866 tgt_non_granular_index,
867 gqa,
868 topology,
869 dtype: _,
870 max_seq_len: _,
871 max_batch_size: _,
872 } => GGMLLoaderBuilder::new(
873 GGMLSpecificConfig {
874 gqa,
875 topology: Topology::from_option_path(topology)?,
876 },
877 args.chat_template,
878 args.tokenizer_json,
879 tok_model_id,
880 quantized_model_id,
881 quantized_filename,
882 args.no_kv_cache,
883 args.jinja_explicit,
884 )
885 .with_xlora(
886 xlora_model_id,
887 serde_json::from_reader(
888 File::open(order.clone())
889 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
890 )?,
891 args.no_kv_cache,
892 tgt_non_granular_index,
893 )
894 .build(),
895 TomlModelSelected::LoraGGML {
896 tok_model_id,
897 quantized_model_id,
898 quantized_filename,
899 adapters_model_id,
900 order,
901 gqa,
902 topology,
903 dtype: _,
904 max_seq_len: _,
905 max_batch_size: _,
906 } => GGMLLoaderBuilder::new(
907 GGMLSpecificConfig {
908 gqa,
909 topology: Topology::from_option_path(topology)?,
910 },
911 args.chat_template,
912 args.tokenizer_json,
913 tok_model_id,
914 quantized_model_id,
915 quantized_filename,
916 args.no_kv_cache,
917 args.jinja_explicit,
918 )
919 .with_lora(
920 adapters_model_id,
921 serde_json::from_reader(
922 File::open(order.clone())
923 .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
924 )?,
925 )
926 .build(),
927 TomlModelSelected::MultimodalPlain {
928 model_id,
929 arch,
930 dtype: _,
931 topology,
932 write_uqff,
933 from_uqff,
934 max_edge,
935 calibration_file,
936 max_seq_len: _,
937 max_batch_size: _,
938 max_num_images: _,
939 max_image_length: _,
940 imatrix,
941 hf_cache_path,
942 organization,
943 } => MultimodalLoaderBuilder::new(
944 MultimodalSpecificConfig {
945 topology: Topology::from_option_path(topology)?,
946 write_uqff,
947 from_uqff: from_uqff.map(|x| {
948 x.split(UQFF_MULTI_FILE_DELIMITER)
949 .map(PathBuf::from_str)
950 .map(|x| x.unwrap())
951 .collect::<Vec<_>>()
952 }),
953 max_edge,
954 calibration_file,
955 imatrix,
956 hf_cache_path,
957 matformer_config_path: None,
958 matformer_slice_name: None,
959 organization: organization.unwrap_or_default(),
960 },
961 args.chat_template,
962 args.tokenizer_json,
963 Some(model_id),
964 args.jinja_explicit,
965 )
966 .build(arch),
967 TomlModelSelected::Embedding {
968 model_id,
969 tokenizer_json,
970 arch,
971 dtype: _,
972 topology,
973 write_uqff,
974 from_uqff,
975 hf_cache_path,
976 } => EmbeddingLoaderBuilder::new(
977 EmbeddingSpecificConfig {
978 topology: Topology::from_option_path(topology)?,
979 write_uqff,
980 from_uqff: from_uqff.map(|x| {
981 x.split(UQFF_MULTI_FILE_DELIMITER)
982 .map(PathBuf::from_str)
983 .map(|x| x.unwrap())
984 .collect::<Vec<_>>()
985 }),
986 hf_cache_path,
987 },
988 tokenizer_json,
989 Some(model_id),
990 )
991 .build(arch),
992 };
993 Ok(loader)
994}
995
996impl TryInto<Box<dyn Loader>> for (TomlSelector, TomlLoaderArgs) {
997 type Error = anyhow::Error;
998 fn try_into(self) -> Result<Box<dyn Loader>, Self::Error> {
999 let (selector, args) = self;
1000 let args = TomlLoaderInnerParams {
1001 chat_template: args.chat_template,
1002 no_kv_cache: args.no_kv_cache,
1003 tokenizer_json: selector.tokenizer_json,
1004 jinja_explicit: args.jinja_explicit,
1005 };
1006 if selector.speculative.is_some() {
1007 anyhow::bail!(
1008 "legacy target/draft speculative decoding in TOML configs was removed; use MTP through --mtp-model or the MTP API instead"
1009 );
1010 }
1011 let loader = loader_from_selected(args.clone(), selector.model)?;
1012 let loader = if let Some(AnyMoeTomlModelSelected {
1013 config,
1014 dataset_json,
1015 prefix,
1016 mlp,
1017 model_ids,
1018 layers,
1019 }) = selector.anymoe
1020 {
1021 Box::new(AnyMoeLoader {
1022 target: loader,
1023 config,
1024 path: dataset_json,
1025 prefix,
1026 mlp,
1027 model_ids,
1028 layers,
1029 })
1030 } else {
1031 loader
1032 };
1033 Ok(loader)
1034 }
1035}