Skip to main content

hanzo_engine/
toml_selector.rs

1use std::{fs::File, path::PathBuf, str::FromStr};
2
3use hanzo_quant::MULTI_LORA_DELIMITER;
4use serde::Deserialize;
5
6use crate::{
7    amoe::AnyMoeConfig,
8    pipeline::{EmbeddingLoaderType, IsqOrganization},
9    AnyMoeLoader, AutoDeviceMapParams, EmbeddingLoaderBuilder, EmbeddingSpecificConfig,
10    GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoaderBuilder, GGUFSpecificConfig, Loader,
11    ModelDType, MultimodalLoaderBuilder, MultimodalLoaderType, MultimodalSpecificConfig,
12    NormalLoaderBuilder, NormalLoaderType, NormalSpecificConfig, Topology,
13    GGUF_MULTI_FILE_DELIMITER, UQFF_MULTI_FILE_DELIMITER,
14};
15
/// Serde default that yields `1` (used, e.g., by the `gqa` fields of the GGML variants).
fn default_one() -> usize {
    1usize
}
19
20fn default_dtype() -> ModelDType {
21    ModelDType::Auto
22}
23
/// Serde default for `Vec<usize>` fields that start out empty (e.g. AnyMoE `layers`).
fn default_empty_vec_usize() -> Vec<usize> {
    vec![]
}
27
28fn default_max_seq_len() -> usize {
29    AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN
30}
31
32fn default_max_batch_size() -> usize {
33    AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE
34}
35
36fn default_max_num_images() -> usize {
37    AutoDeviceMapParams::DEFAULT_MAX_NUM_IMAGES
38}
39
40fn default_max_image_length() -> usize {
41    AutoDeviceMapParams::DEFAULT_MAX_IMAGE_LENGTH
42}
43
44#[derive(Debug, Deserialize)]
45#[serde(untagged)]
46pub enum TomlModelSelected {
47    /// Select a plain model, without quantization or adapters
48    Plain {
49        /// Model ID to load from. This may be a HF hub repo or a local path.
50        model_id: String,
51
52        /// The architecture of the model.
53        arch: Option<NormalLoaderType>,
54
55        /// Model data type. Defaults to `auto`.
56        #[serde(default = "default_dtype")]
57        dtype: ModelDType,
58
59        /// Path to a topology YAML file.
60        topology: Option<String>,
61
62        /// ISQ organization: `default` or `moqe` (Mixture of Quantized Experts: https://arxiv.org/abs/2310.02410).
63        organization: Option<IsqOrganization>,
64
65        /// UQFF path to write to.
66        write_uqff: Option<PathBuf>,
67
68        /// UQFF path to load from. If provided, this takes precedence over applying ISQ.
69        from_uqff: Option<String>,
70
71        /// .imatrix file to enhance GGUF quantizations with.
72        /// Incompatible with `--imatrix/-i`
73        imatrix: Option<PathBuf>,
74
75        /// Generate and utilize an imatrix to enhance GGUF quantizations.
76        /// Incompatible with `--imatrix/-i`
77        calibration_file: Option<PathBuf>,
78
79        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
80        #[serde(default = "default_max_seq_len")]
81        max_seq_len: usize,
82
83        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
84        #[serde(default = "default_max_batch_size")]
85        max_batch_size: usize,
86
87        /// Cache path for Hugging Face models downloaded locally
88        hf_cache_path: Option<PathBuf>,
89    },
90
91    /// Select an X-LoRA architecture
92    XLora {
93        /// Force a base model ID to load from instead of using the ordering file. This may be a HF hub repo or a local path.
94        model_id: Option<String>,
95
96        /// Model ID to load X-LoRA from. This may be a HF hub repo or a local path.
97        xlora_model_id: String,
98
99        /// Ordering JSON file
100        order: String,
101
102        /// Index of completion tokens to generate scalings up until. If this is 1, then there will be one completion token generated before it is cached.
103        /// This makes the maximum running sequences 1.
104        tgt_non_granular_index: Option<usize>,
105
106        /// The architecture of the model.
107        arch: Option<NormalLoaderType>,
108
109        /// Model data type. Defaults to `auto`.
110        #[serde(default = "default_dtype")]
111        dtype: ModelDType,
112
113        /// Path to a topology YAML file.
114        topology: Option<String>,
115
116        /// UQFF path to write to.
117        write_uqff: Option<PathBuf>,
118
119        /// UQFF path to load from. If provided, this takes precedence over applying ISQ.
120        from_uqff: Option<String>,
121
122        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
123        #[serde(default = "default_max_seq_len")]
124        max_seq_len: usize,
125
126        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
127        #[serde(default = "default_max_batch_size")]
128        max_batch_size: usize,
129
130        /// Cache path for Hugging Face models downloaded locally
131        hf_cache_path: Option<PathBuf>,
132    },
133
134    /// Select a LoRA architecture
135    Lora {
136        /// Force a base model ID to load from instead of using the ordering file. This may be a HF hub repo or a local path.
137        model_id: Option<String>,
138
139        /// Model IDs to load LoRA from. This may be a HF hub repo or a local path. Specify multiple with a semicolon.
140        adapter_model_ids: String,
141
142        /// The architecture of the model.
143        arch: Option<NormalLoaderType>,
144
145        /// Model data type. Defaults to `auto`.
146        #[serde(default = "default_dtype")]
147        dtype: ModelDType,
148
149        /// Path to a topology YAML file.
150        topology: Option<String>,
151
152        /// UQFF path to write to.
153        write_uqff: Option<PathBuf>,
154
155        /// UQFF path to load from. If provided, this takes precedence over applying ISQ.
156        from_uqff: Option<String>,
157
158        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
159        #[serde(default = "default_max_seq_len")]
160        max_seq_len: usize,
161
162        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
163        #[serde(default = "default_max_batch_size")]
164        max_batch_size: usize,
165
166        /// Cache path for Hugging Face models downloaded locally
167        hf_cache_path: Option<PathBuf>,
168    },
169
170    /// Select a GGUF model.
171    #[allow(clippy::upper_case_acronyms)]
172    GGUF {
173        /// `tok_model_id` is the local or remote model ID where you can find a `tokenizer_config.json` file.
174        /// If the `chat_template` is specified, then it will be treated as a path and used over remote files,
175        /// removing all remote accesses.
176        tok_model_id: String,
177
178        /// Quantized model ID to find the `quantized_filename`.
179        /// This may be a HF hub repo or a local path.
180        quantized_model_id: String,
181
182        /// Quantized filename(s).
183        /// May be a single filename, or use a delimiter of " " (a single space) for multiple files.
184        quantized_filename: String,
185
186        /// Model data type. Defaults to `auto`.
187        #[serde(default = "default_dtype")]
188        dtype: ModelDType,
189
190        /// Path to a topology YAML file.
191        topology: Option<String>,
192
193        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
194        #[serde(default = "default_max_seq_len")]
195        max_seq_len: usize,
196
197        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
198        #[serde(default = "default_max_batch_size")]
199        max_batch_size: usize,
200    },
201
202    /// Select a GGUF model with X-LoRA.
203    XLoraGGUF {
204        /// `tok_model_id` is the local or remote model ID where you can find a `tokenizer_config.json` file.
205        /// If the `chat_template` is specified, then it will be treated as a path and used over remote files,
206        /// removing all remote accesses.
207        tok_model_id: Option<String>,
208
209        /// Quantized model ID to find the `quantized_filename`.
210        /// This may be a HF hub repo or a local path.
211        quantized_model_id: String,
212
213        /// Quantized filename(s).
214        /// May be a single filename, or use a delimiter of " " (a single space) for multiple files.
215        quantized_filename: String,
216
217        /// Model ID to load X-LoRA from. This may be a HF hub repo or a local path.
218        xlora_model_id: String,
219
220        /// Ordering JSON file
221        order: String,
222
223        /// Index of completion tokens to generate scalings up until. If this is 1, then there will be one completion token generated before it is cached.
224        /// This makes the maximum running sequences 1.
225        tgt_non_granular_index: Option<usize>,
226
227        /// Model data type. Defaults to `auto`.
228        #[serde(default = "default_dtype")]
229        dtype: ModelDType,
230
231        /// Path to a topology YAML file.
232        topology: Option<String>,
233
234        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
235        #[serde(default = "default_max_seq_len")]
236        max_seq_len: usize,
237
238        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
239        #[serde(default = "default_max_batch_size")]
240        max_batch_size: usize,
241    },
242
243    /// Select a GGUF model with LoRA.
244    LoraGGUF {
245        /// `tok_model_id` is the local or remote model ID where you can find a `tokenizer_config.json` file.
246        /// If the `chat_template` is specified, then it will be treated as a path and used over remote files,
247        /// removing all remote accesses.
248        tok_model_id: Option<String>,
249
250        /// Quantized model ID to find the `quantized_filename`.
251        /// This may be a HF hub repo or a local path.
252        quantized_model_id: String,
253
254        /// Quantized filename(s).
255        /// May be a single filename, or use a delimiter of " " (a single space) for multiple files.
256        quantized_filename: String,
257
258        /// Model ID to load LoRA from. This may be a HF hub repo or a local path.
259        adapters_model_id: String,
260
261        /// Ordering JSON file
262        order: String,
263
264        /// Model data type. Defaults to `auto`.
265        #[serde(default = "default_dtype")]
266        dtype: ModelDType,
267
268        /// Path to a topology YAML file.
269        topology: Option<String>,
270
271        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
272        #[serde(default = "default_max_seq_len")]
273        max_seq_len: usize,
274
275        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
276        #[serde(default = "default_max_batch_size")]
277        max_batch_size: usize,
278    },
279
280    /// Select a GGML model.
281    #[allow(clippy::upper_case_acronyms)]
282    GGML {
283        /// Model ID to load the tokenizer from. This may be a HF hub repo or a local path.
284        tok_model_id: String,
285
286        /// Quantized model ID to find the `quantized_filename`.
287        /// This may be a HF hub repo or a local path.
288        quantized_model_id: String,
289
290        /// Quantized filename.
291        quantized_filename: String,
292
293        /// GQA value
294        #[serde(default = "default_one")]
295        gqa: usize,
296
297        /// Model data type. Defaults to `auto`.
298        #[serde(default = "default_dtype")]
299        dtype: ModelDType,
300
301        /// Path to a topology YAML file.
302        topology: Option<String>,
303
304        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
305        #[serde(default = "default_max_seq_len")]
306        max_seq_len: usize,
307
308        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
309        #[serde(default = "default_max_batch_size")]
310        max_batch_size: usize,
311    },
312
313    /// Select a GGML model with X-LoRA.
314    XLoraGGML {
315        /// Model ID to load the tokenizer from. This may be a HF hub repo or a local path.
316        tok_model_id: Option<String>,
317
318        /// Quantized model ID to find the `quantized_filename`.
319        /// This may be a HF hub repo or a local path.
320        quantized_model_id: String,
321
322        /// Quantized filename.
323        quantized_filename: String,
324
325        /// Model ID to load X-LoRA from. This may be a HF hub repo or a local path.
326        xlora_model_id: String,
327
328        /// Ordering JSON file
329        order: String,
330
331        /// Index of completion tokens to generate scalings up until. If this is 1, then there will be one completion token generated before it is cached.
332        /// This makes the maximum running sequences 1.
333        tgt_non_granular_index: Option<usize>,
334
335        /// GQA value
336        #[serde(default = "default_one")]
337        gqa: usize,
338
339        /// Model data type. Defaults to `auto`.
340        #[serde(default = "default_dtype")]
341        dtype: ModelDType,
342
343        /// Path to a topology YAML file.
344        topology: Option<String>,
345
346        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
347        #[serde(default = "default_max_seq_len")]
348        max_seq_len: usize,
349
350        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
351        #[serde(default = "default_max_batch_size")]
352        max_batch_size: usize,
353    },
354
355    /// Select a GGML model with LoRA.
356    LoraGGML {
357        /// Model ID to load the tokenizer from. This may be a HF hub repo or a local path.
358        tok_model_id: Option<String>,
359
360        /// Quantized model ID to find the `quantized_filename`.
361        /// This may be a HF hub repo or a local path.
362        quantized_model_id: String,
363
364        /// Quantized filename.
365        quantized_filename: String,
366
367        /// Model ID to load LoRA from. This may be a HF hub repo or a local path.
368        adapters_model_id: String,
369
370        /// Ordering JSON file
371        order: String,
372
373        /// GQA value
374        #[serde(default = "default_one")]
375        gqa: usize,
376
377        /// Model data type. Defaults to `auto`.
378        #[serde(default = "default_dtype")]
379        dtype: ModelDType,
380
381        /// Path to a topology YAML file.
382        topology: Option<String>,
383
384        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
385        #[serde(default = "default_max_seq_len")]
386        max_seq_len: usize,
387
388        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
389        #[serde(default = "default_max_batch_size")]
390        max_batch_size: usize,
391    },
392
393    /// Select a multimodal plain model, without quantization or adapters
394    MultimodalPlain {
395        /// Model ID to load from. This may be a HF hub repo or a local path.
396        model_id: String,
397
398        /// The architecture of the model.
399        arch: Option<MultimodalLoaderType>,
400
401        /// Model data type. Defaults to `auto`.
402        #[serde(default = "default_dtype")]
403        dtype: ModelDType,
404
405        /// Path to a topology YAML file.
406        topology: Option<String>,
407
408        /// UQFF path to write to.
409        write_uqff: Option<PathBuf>,
410
411        /// UQFF path to load from. If provided, this takes precedence over applying ISQ.
412        from_uqff: Option<String>,
413
414        /// Automatically resize and pad images to this maximum edge length. Aspect ratio is preserved.
415        /// This is only supported on the Qwen2-VL and Idefics 2 models. Others handle this internally.
416        max_edge: Option<u32>,
417
418        /// Generate and utilize an imatrix to enhance GGUF quantizations.
419        calibration_file: Option<PathBuf>,
420
421        /// .cimatrix file to enhance GGUF quantizations with. This must be a .cimatrix file.
422        imatrix: Option<PathBuf>,
423
424        /// Maximum prompt sequence length to expect for this model. This affects automatic device mapping but is not a hard limit.
425        #[serde(default = "default_max_seq_len")]
426        max_seq_len: usize,
427
428        /// Maximum prompt batch size to expect for this model. This affects automatic device mapping but is not a hard limit.
429        #[serde(default = "default_max_batch_size")]
430        max_batch_size: usize,
431
432        /// Maximum prompt number of images to expect for this model. This affects automatic device mapping but is not a hard limit.
433        #[serde(default = "default_max_num_images")]
434        max_num_images: usize,
435
436        /// Maximum expected image size will have this edge length on both edges.
437        /// This affects automatic device mapping but is not a hard limit.
438        #[serde(default = "default_max_image_length")]
439        max_image_length: usize,
440
441        /// Cache path for Hugging Face models downloaded locally
442        hf_cache_path: Option<PathBuf>,
443
444        /// ISQ organization: `default` or `moqe`.
445        organization: Option<IsqOrganization>,
446    },
447
448    /// Select an embedding model, without quantization or adapters
449    Embedding {
450        /// Model ID to load from. This may be a HF hub repo or a local path.
451        model_id: String,
452
453        /// Path to local tokenizer.json file. If this is specified it is used over any remote file.
454        #[serde(default)]
455        tokenizer_json: Option<String>,
456
457        /// The architecture of the model.
458        #[serde(default)]
459        arch: Option<EmbeddingLoaderType>,
460
461        /// Model data type. Defaults to `auto`.
462        #[serde(default = "default_dtype")]
463        dtype: ModelDType,
464
465        /// Path to a topology YAML file.
466        #[serde(default)]
467        topology: Option<String>,
468
469        /// UQFF path to write to.
470        #[serde(default)]
471        write_uqff: Option<PathBuf>,
472
473        /// UQFF path to load from. If provided, this takes precedence over applying ISQ. Specify multiple files using a semicolon delimiter (;)
474        #[serde(default)]
475        from_uqff: Option<String>,
476
477        /// Cache path for Hugging Face models downloaded locally
478        #[serde(default)]
479        hf_cache_path: Option<PathBuf>,
480    },
481}
482
483#[derive(Deserialize)]
484pub struct AnyMoeTomlModelSelected {
485    /// Config
486    config: AnyMoeConfig,
487
488    /// Base model
489    dataset_json: String,
490
491    /// Prefix of the mlp key (the part before the layer number: "a.b.c" in "a.b.c.0.mlp")
492    prefix: String,
493
494    /// Name of the mlp key (the part before the layer number: "mlp" in "a.b.c.0.mlp")
495    mlp: String,
496
497    /// Expert model ids
498    model_ids: Vec<String>,
499
500    /// Layer ids (zero indexed) of layers to apply AnyMoE to, if empty will use all
501    #[serde(default = "default_empty_vec_usize")]
502    layers: Vec<usize>,
503}
504
505#[derive(Deserialize)]
506pub struct TomlSelector {
507    /// Path to local tokenizer.json file. If this is specified it is used over any remote file.
508    tokenizer_json: Option<String>,
509
510    /// Selected model
511    model: TomlModelSelected,
512
513    /// Legacy target/draft speculative decoding was removed. Keep this field
514    /// only to reject old configs explicitly instead of silently ignoring them.
515    #[serde(default)]
516    speculative: Option<serde::de::IgnoredAny>,
517
518    /// AnyMoE config
519    anymoe: Option<AnyMoeTomlModelSelected>,
520}
521
/// Per-load settings that `loader_from_selected` forwards to whichever loader builder
/// the selected model variant needs.
#[derive(Clone)]
struct TomlLoaderInnerParams {
    /// Chat template override passed through to the builder.
    chat_template: Option<String>,
    /// When `true`, the builders are told not to use a KV cache.
    no_kv_cache: bool,
    /// Local `tokenizer.json` path, taken from the top-level TOML setting.
    tokenizer_json: Option<String>,
    /// Explicit Jinja template value passed through to the builder
    /// (presumably a path or template source — confirm against the builders).
    jinja_explicit: Option<String>,
}
529
/// Caller-supplied options that accompany a [`TomlSelector`] when building a loader.
pub struct TomlLoaderArgs {
    /// Optional chat template override.
    pub chat_template: Option<String>,
    /// Set to `true` to disable the KV cache.
    pub no_kv_cache: bool,
    /// Optional explicit Jinja template value.
    pub jinja_explicit: Option<String>,
}
535
536pub fn get_toml_selected_model_dtype(model: &TomlSelector) -> ModelDType {
537    match model.model {
538        TomlModelSelected::Plain { dtype, .. }
539        | TomlModelSelected::Lora { dtype, .. }
540        | TomlModelSelected::XLora { dtype, .. }
541        | TomlModelSelected::MultimodalPlain { dtype, .. }
542        | TomlModelSelected::GGUF { dtype, .. }
543        | TomlModelSelected::GGML { dtype, .. }
544        | TomlModelSelected::XLoraGGUF { dtype, .. }
545        | TomlModelSelected::XLoraGGML { dtype, .. }
546        | TomlModelSelected::LoraGGUF { dtype, .. }
547        | TomlModelSelected::LoraGGML { dtype, .. }
548        | TomlModelSelected::Embedding { dtype, .. } => dtype,
549    }
550}
551
552pub fn get_toml_selected_model_device_map_params(
553    model: &TomlSelector,
554) -> anyhow::Result<AutoDeviceMapParams> {
555    match model.model {
556        TomlModelSelected::Plain {
557            max_seq_len,
558            max_batch_size,
559            ..
560        }
561        | TomlModelSelected::Lora {
562            max_seq_len,
563            max_batch_size,
564            ..
565        }
566        | TomlModelSelected::XLora {
567            max_seq_len,
568            max_batch_size,
569            ..
570        }
571        | TomlModelSelected::GGML {
572            max_seq_len,
573            max_batch_size,
574            ..
575        }
576        | TomlModelSelected::GGUF {
577            max_seq_len,
578            max_batch_size,
579            ..
580        }
581        | TomlModelSelected::XLoraGGUF {
582            max_seq_len,
583            max_batch_size,
584            ..
585        }
586        | TomlModelSelected::XLoraGGML {
587            max_seq_len,
588            max_batch_size,
589            ..
590        }
591        | TomlModelSelected::LoraGGUF {
592            max_seq_len,
593            max_batch_size,
594            ..
595        }
596        | TomlModelSelected::LoraGGML {
597            max_seq_len,
598            max_batch_size,
599            ..
600        } => Ok(AutoDeviceMapParams::Text {
601            max_seq_len,
602            max_batch_size,
603        }),
604        TomlModelSelected::Embedding { .. } => Ok(AutoDeviceMapParams::default_text()),
605        TomlModelSelected::MultimodalPlain {
606            max_seq_len,
607            max_batch_size,
608            max_image_length,
609            max_num_images,
610            ..
611        } => Ok(AutoDeviceMapParams::Multimodal {
612            max_seq_len,
613            max_batch_size,
614            max_image_shape: (max_image_length, max_image_length),
615            max_num_images,
616        }),
617    }
618}
619
620fn loader_from_selected(
621    args: TomlLoaderInnerParams,
622    model: TomlModelSelected,
623) -> anyhow::Result<Box<dyn Loader>> {
624    let loader: Box<dyn Loader> = match model {
625        TomlModelSelected::Plain {
626            model_id,
627            arch,
628            dtype: _,
629            topology,
630            organization,
631            write_uqff,
632            from_uqff,
633            imatrix,
634            calibration_file,
635            max_seq_len: _,
636            max_batch_size: _,
637            hf_cache_path,
638        } => NormalLoaderBuilder::new(
639            NormalSpecificConfig {
640                topology: Topology::from_option_path(topology)?,
641                organization: organization.unwrap_or_default(),
642                write_uqff,
643                from_uqff: from_uqff.map(|x| {
644                    x.split(UQFF_MULTI_FILE_DELIMITER)
645                        .map(PathBuf::from_str)
646                        .map(|x| x.unwrap())
647                        .collect::<Vec<_>>()
648                }),
649                imatrix,
650                calibration_file,
651                hf_cache_path,
652                matformer_config_path: None,
653                matformer_slice_name: None,
654            },
655            args.chat_template,
656            args.tokenizer_json,
657            Some(model_id),
658            args.no_kv_cache,
659            args.jinja_explicit,
660        )
661        .build(arch)?,
662        TomlModelSelected::XLora {
663            model_id,
664            xlora_model_id,
665            order,
666            tgt_non_granular_index,
667            arch,
668            dtype: _,
669            topology,
670            write_uqff,
671            from_uqff,
672            max_seq_len: _,
673            max_batch_size: _,
674            hf_cache_path,
675        } => NormalLoaderBuilder::new(
676            NormalSpecificConfig {
677                topology: Topology::from_option_path(topology)?,
678                organization: Default::default(),
679                write_uqff,
680                from_uqff: from_uqff.map(|x| {
681                    x.split(UQFF_MULTI_FILE_DELIMITER)
682                        .map(PathBuf::from_str)
683                        .map(|x| x.unwrap())
684                        .collect::<Vec<_>>()
685                }),
686                imatrix: None,
687                calibration_file: None,
688                hf_cache_path,
689                matformer_config_path: None,
690                matformer_slice_name: None,
691            },
692            args.chat_template,
693            args.tokenizer_json,
694            model_id,
695            args.no_kv_cache,
696            args.jinja_explicit,
697        )
698        .with_xlora(
699            xlora_model_id,
700            serde_json::from_reader(
701                File::open(order.clone())
702                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
703            )?,
704            args.no_kv_cache,
705            tgt_non_granular_index,
706        )
707        .build(arch)?,
708        TomlModelSelected::Lora {
709            model_id,
710            adapter_model_ids,
711            arch,
712            dtype: _,
713            topology,
714            write_uqff,
715            from_uqff,
716            max_seq_len: _,
717            max_batch_size: _,
718            hf_cache_path,
719        } => NormalLoaderBuilder::new(
720            NormalSpecificConfig {
721                topology: Topology::from_option_path(topology)?,
722                organization: Default::default(),
723                write_uqff,
724                from_uqff: from_uqff.map(|x| {
725                    x.split(UQFF_MULTI_FILE_DELIMITER)
726                        .map(PathBuf::from_str)
727                        .map(|x| x.unwrap())
728                        .collect::<Vec<_>>()
729                }),
730                imatrix: None,
731                calibration_file: None,
732                hf_cache_path,
733                matformer_config_path: None,
734                matformer_slice_name: None,
735            },
736            args.chat_template,
737            args.tokenizer_json,
738            model_id,
739            args.no_kv_cache,
740            args.jinja_explicit,
741        )
742        .with_lora(
743            adapter_model_ids
744                .split(MULTI_LORA_DELIMITER)
745                .map(ToString::to_string)
746                .collect(),
747        )
748        .build(arch)?,
749        TomlModelSelected::GGUF {
750            tok_model_id,
751            quantized_model_id,
752            quantized_filename,
753            topology,
754            dtype: _,
755            max_seq_len: _,
756            max_batch_size: _,
757        } => GGUFLoaderBuilder::new(
758            args.chat_template,
759            Some(tok_model_id),
760            quantized_model_id,
761            quantized_filename
762                .split(GGUF_MULTI_FILE_DELIMITER)
763                .map(ToOwned::to_owned)
764                .collect::<Vec<_>>(),
765            GGUFSpecificConfig {
766                topology: Topology::from_option_path(topology)?,
767            },
768            args.no_kv_cache,
769            args.jinja_explicit,
770        )
771        .build(),
772        TomlModelSelected::XLoraGGUF {
773            tok_model_id,
774            quantized_model_id,
775            quantized_filename,
776            xlora_model_id,
777            order,
778            tgt_non_granular_index,
779            topology,
780            dtype: _,
781            max_seq_len: _,
782            max_batch_size: _,
783        } => GGUFLoaderBuilder::new(
784            args.chat_template,
785            tok_model_id,
786            quantized_model_id,
787            quantized_filename
788                .split(GGUF_MULTI_FILE_DELIMITER)
789                .map(ToOwned::to_owned)
790                .collect::<Vec<_>>(),
791            GGUFSpecificConfig {
792                topology: Topology::from_option_path(topology)?,
793            },
794            args.no_kv_cache,
795            args.jinja_explicit,
796        )
797        .with_xlora(
798            xlora_model_id,
799            serde_json::from_reader(
800                File::open(order.clone())
801                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
802            )?,
803            args.no_kv_cache,
804            tgt_non_granular_index,
805        )
806        .build(),
807        TomlModelSelected::LoraGGUF {
808            tok_model_id,
809            quantized_model_id,
810            quantized_filename,
811            adapters_model_id,
812            order,
813            topology,
814            ..
815        } => GGUFLoaderBuilder::new(
816            args.chat_template,
817            tok_model_id,
818            quantized_model_id,
819            quantized_filename
820                .split(GGUF_MULTI_FILE_DELIMITER)
821                .map(ToOwned::to_owned)
822                .collect::<Vec<_>>(),
823            GGUFSpecificConfig {
824                topology: Topology::from_option_path(topology)?,
825            },
826            args.no_kv_cache,
827            args.jinja_explicit,
828        )
829        .with_lora(
830            adapters_model_id,
831            serde_json::from_reader(
832                File::open(order.clone())
833                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
834            )?,
835        )
836        .build(),
837        TomlModelSelected::GGML {
838            tok_model_id,
839            quantized_model_id,
840            quantized_filename,
841            gqa,
842            topology,
843            dtype: _,
844            max_seq_len: _,
845            max_batch_size: _,
846        } => GGMLLoaderBuilder::new(
847            GGMLSpecificConfig {
848                gqa,
849                topology: Topology::from_option_path(topology)?,
850            },
851            args.chat_template,
852            args.tokenizer_json,
853            Some(tok_model_id),
854            quantized_model_id,
855            quantized_filename,
856            args.no_kv_cache,
857            args.jinja_explicit,
858        )
859        .build(),
860        TomlModelSelected::XLoraGGML {
861            tok_model_id,
862            quantized_model_id,
863            quantized_filename,
864            xlora_model_id,
865            order,
866            tgt_non_granular_index,
867            gqa,
868            topology,
869            dtype: _,
870            max_seq_len: _,
871            max_batch_size: _,
872        } => GGMLLoaderBuilder::new(
873            GGMLSpecificConfig {
874                gqa,
875                topology: Topology::from_option_path(topology)?,
876            },
877            args.chat_template,
878            args.tokenizer_json,
879            tok_model_id,
880            quantized_model_id,
881            quantized_filename,
882            args.no_kv_cache,
883            args.jinja_explicit,
884        )
885        .with_xlora(
886            xlora_model_id,
887            serde_json::from_reader(
888                File::open(order.clone())
889                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
890            )?,
891            args.no_kv_cache,
892            tgt_non_granular_index,
893        )
894        .build(),
895        TomlModelSelected::LoraGGML {
896            tok_model_id,
897            quantized_model_id,
898            quantized_filename,
899            adapters_model_id,
900            order,
901            gqa,
902            topology,
903            dtype: _,
904            max_seq_len: _,
905            max_batch_size: _,
906        } => GGMLLoaderBuilder::new(
907            GGMLSpecificConfig {
908                gqa,
909                topology: Topology::from_option_path(topology)?,
910            },
911            args.chat_template,
912            args.tokenizer_json,
913            tok_model_id,
914            quantized_model_id,
915            quantized_filename,
916            args.no_kv_cache,
917            args.jinja_explicit,
918        )
919        .with_lora(
920            adapters_model_id,
921            serde_json::from_reader(
922                File::open(order.clone())
923                    .unwrap_or_else(|_| panic!("Could not load ordering file at {order}")),
924            )?,
925        )
926        .build(),
927        TomlModelSelected::MultimodalPlain {
928            model_id,
929            arch,
930            dtype: _,
931            topology,
932            write_uqff,
933            from_uqff,
934            max_edge,
935            calibration_file,
936            max_seq_len: _,
937            max_batch_size: _,
938            max_num_images: _,
939            max_image_length: _,
940            imatrix,
941            hf_cache_path,
942            organization,
943        } => MultimodalLoaderBuilder::new(
944            MultimodalSpecificConfig {
945                topology: Topology::from_option_path(topology)?,
946                write_uqff,
947                from_uqff: from_uqff.map(|x| {
948                    x.split(UQFF_MULTI_FILE_DELIMITER)
949                        .map(PathBuf::from_str)
950                        .map(|x| x.unwrap())
951                        .collect::<Vec<_>>()
952                }),
953                max_edge,
954                calibration_file,
955                imatrix,
956                hf_cache_path,
957                matformer_config_path: None,
958                matformer_slice_name: None,
959                organization: organization.unwrap_or_default(),
960            },
961            args.chat_template,
962            args.tokenizer_json,
963            Some(model_id),
964            args.jinja_explicit,
965        )
966        .build(arch),
967        TomlModelSelected::Embedding {
968            model_id,
969            tokenizer_json,
970            arch,
971            dtype: _,
972            topology,
973            write_uqff,
974            from_uqff,
975            hf_cache_path,
976        } => EmbeddingLoaderBuilder::new(
977            EmbeddingSpecificConfig {
978                topology: Topology::from_option_path(topology)?,
979                write_uqff,
980                from_uqff: from_uqff.map(|x| {
981                    x.split(UQFF_MULTI_FILE_DELIMITER)
982                        .map(PathBuf::from_str)
983                        .map(|x| x.unwrap())
984                        .collect::<Vec<_>>()
985                }),
986                hf_cache_path,
987            },
988            tokenizer_json,
989            Some(model_id),
990        )
991        .build(arch),
992    };
993    Ok(loader)
994}
995
996impl TryInto<Box<dyn Loader>> for (TomlSelector, TomlLoaderArgs) {
997    type Error = anyhow::Error;
998    fn try_into(self) -> Result<Box<dyn Loader>, Self::Error> {
999        let (selector, args) = self;
1000        let args = TomlLoaderInnerParams {
1001            chat_template: args.chat_template,
1002            no_kv_cache: args.no_kv_cache,
1003            tokenizer_json: selector.tokenizer_json,
1004            jinja_explicit: args.jinja_explicit,
1005        };
1006        if selector.speculative.is_some() {
1007            anyhow::bail!(
1008                "legacy target/draft speculative decoding in TOML configs was removed; use MTP through --mtp-model or the MTP API instead"
1009            );
1010        }
1011        let loader = loader_from_selected(args.clone(), selector.model)?;
1012        let loader = if let Some(AnyMoeTomlModelSelected {
1013            config,
1014            dataset_json,
1015            prefix,
1016            mlp,
1017            model_ids,
1018            layers,
1019        }) = selector.anymoe
1020        {
1021            Box::new(AnyMoeLoader {
1022                target: loader,
1023                config,
1024                path: dataset_json,
1025                prefix,
1026                mlp,
1027                model_ids,
1028                layers,
1029            })
1030        } else {
1031            loader
1032        };
1033        Ok(loader)
1034    }
1035}