1use derive_builder::Builder;
2use hf_hub::api::sync::ApiError;
3use strum::{EnumDiscriminants, EnumString, VariantNames};
4use subenum::subenum;
5
6use crate::{
7 api::{Config, ConfigBuilder, ConfigBuilderError, ModelConfig, ModelConfigBuilder},
8 preset_builder::{
9 anima, anima2, chroma, chroma_radiance, diff_instruct_star, dream_shaper_xl_2_1_turbo,
10 ernie_image, ernie_image_turbo, flux_1_dev, flux_1_mini, flux_1_schnell, flux_2_dev,
11 flux_2_klein_4b, flux_2_klein_9b, flux_2_klein_base_4b, flux_2_klein_base_9b,
12 hi_dream_o1_image, hi_dream_o1_image_dev, juggernaut_xl_11, lens, lens_turbo,
13 long_cat_image, nitro_sd_realism, nitro_sd_vibrant, ovis_image, qwen_image, sd_turbo,
14 sdxl_base_1_0, sdxl_turbo_1_0, sdxs512_dream_shaper, segmind_vega, ssd_1b,
15 stable_diffusion_1_4, stable_diffusion_1_5, stable_diffusion_2_1,
16 stable_diffusion_3_5_large, stable_diffusion_3_5_large_turbo, stable_diffusion_3_5_medium,
17 stable_diffusion_3_medium, twinflow_z_image_turbo, z_image_turbo,
18 },
19};
20
/// Weight (tensor data) types that a preset can be loaded with.
///
/// Parsing via `FromStr` (derived by strum's `EnumString`) is ASCII case-insensitive,
/// e.g. `"q8_0"` parses to [`WeightType::Q8_0`].
///
/// Each `#[subenum(...)]` attribute below generates a narrower enum (e.g. `Flux1Weight`)
/// that contains only the variants tagged with its name. A `(default)` marker selects the
/// variant returned by that subenum's derived `Default`. Variants with no `#[subenum]`
/// attribute (e.g. `Q8_1`, `IQ2_XXS`) belong to no subenum and therefore cannot be
/// selected through any `Preset` variant.
///
/// NOTE(review): the variant order appears to mirror the upstream stable-diffusion.cpp
/// `sd_type_t` list — confirm before reordering or inserting variants.
#[non_exhaustive]
// Variant names intentionally follow the upstream quantization names (`Q4_0`, `IQ2_XXS`, ...).
#[allow(non_camel_case_types)]
#[subenum(
    Flux1Weight(derive(Default)),
    Flux1MiniWeight(derive(Default)),
    ChromaWeight(derive(Default)),
    NitroSDRealismWeight(derive(Default)),
    NitroSDVibrantWeight(derive(Default)),
    DiffInstructStarWeight(derive(Default)),
    ChromaRadianceWeight(derive(Default)),
    SSD1BWeight(derive(Default)),
    Flux2Weight(derive(Default)),
    ZImageTurboWeight(derive(Default)),
    QwenImageWeight(derive(Default)),
    OvisImageWeight(derive(Default)),
    TwinFlowZImageTurboExpWeight(derive(Default)),
    Flux2Klein4BWeight(derive(Default)),
    Flux2KleinBase4BWeight(derive(Default)),
    Flux2Klein9BWeight(derive(Default)),
    Flux2KleinBase9BWeight(derive(Default)),
    AnimaWeight(derive(Default)),
    Anima2Weight(derive(Default)),
    SDXS512DreamShaperWeight(derive(Default)),
    ErnieImageWeight(derive(Default)),
    LongCatImageWeight(derive(Default))
)]
#[derive(Debug, Clone, Copy, EnumString, VariantNames)]
#[strum(ascii_case_insensitive)]
pub enum WeightType {
    #[subenum(Flux1MiniWeight)]
    F32,
    #[subenum(
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        SSD1BWeight,
        SDXS512DreamShaperWeight(default),
        ErnieImageWeight
    )]
    F16,
    #[subenum(
        Flux1Weight,
        ChromaWeight(default),
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        OvisImageWeight(default),
        TwinFlowZImageTurboExpWeight(default),
        Flux2Klein4BWeight,
        Flux2KleinBase4BWeight,
        Flux2Klein9BWeight(default),
        Flux2KleinBase9BWeight(default),
        AnimaWeight,
        ErnieImageWeight(default),
        LongCatImageWeight(default)
    )]
    Q4_0,
    #[subenum(
        Flux2Weight,
        QwenImageWeight,
        AnimaWeight,
        ErnieImageWeight,
        LongCatImageWeight
    )]
    Q4_1,
    #[subenum(
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        TwinFlowZImageTurboExpWeight,
        AnimaWeight,
        ErnieImageWeight,
        LongCatImageWeight
    )]
    Q5_0,
    #[subenum(
        Flux2Weight,
        QwenImageWeight,
        AnimaWeight,
        ErnieImageWeight,
        LongCatImageWeight
    )]
    Q5_1,
    #[subenum(
        Flux1Weight,
        Flux1MiniWeight(default),
        ChromaWeight,
        NitroSDRealismWeight(default),
        NitroSDVibrantWeight(default),
        DiffInstructStarWeight(default),
        ChromaRadianceWeight(default),
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        OvisImageWeight,
        TwinFlowZImageTurboExpWeight,
        Flux2Klein4BWeight(default),
        Flux2KleinBase4BWeight(default),
        Flux2Klein9BWeight,
        AnimaWeight(default),
        Anima2Weight(default),
        SDXS512DreamShaperWeight,
        ErnieImageWeight,
        LongCatImageWeight
    )]
    Q8_0,
    // Not part of any subenum: no preset accepts it.
    Q8_1,
    #[subenum(
        Flux1Weight(default),
        Flux1MiniWeight,
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        Flux2Weight(default),
        ZImageTurboWeight,
        QwenImageWeight(default),
        ErnieImageWeight
    )]
    Q2_K,
    #[subenum(
        Flux1Weight,
        Flux1MiniWeight,
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        ZImageTurboWeight,
        Flux2Weight,
        QwenImageWeight,
        TwinFlowZImageTurboExpWeight,
        AnimaWeight,
        ErnieImageWeight,
        LongCatImageWeight
    )]
    Q3_K,
    #[subenum(
        Flux1Weight,
        ZImageTurboWeight(default),
        Flux2Weight,
        QwenImageWeight,
        AnimaWeight,
        Anima2Weight,
        ErnieImageWeight,
        LongCatImageWeight
    )]
    Q4_K,
    #[subenum(
        Flux1MiniWeight,
        Flux2Weight,
        QwenImageWeight,
        AnimaWeight,
        Anima2Weight,
        ErnieImageWeight,
        LongCatImageWeight
    )]
    Q5_K,
    #[subenum(
        Flux1MiniWeight,
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        TwinFlowZImageTurboExpWeight,
        AnimaWeight,
        Anima2Weight,
        ErnieImageWeight,
        LongCatImageWeight
    )]
    Q6_K,
    // `Q8_K` through `IQ1_M` carry no `#[subenum]` tag, so no preset accepts them.
    Q8_K,
    IQ2_XXS,
    IQ2_XS,
    IQ3_XXS,
    IQ1_S,
    IQ4_NL,
    IQ3_S,
    IQ2_S,
    IQ4_XS,
    I8,
    I16,
    I32,
    I64,
    F64,
    IQ1_M,
    #[subenum(
        Flux1MiniWeight,
        ChromaWeight,
        ChromaRadianceWeight,
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        OvisImageWeight,
        TwinFlowZImageTurboExpWeight,
        Flux2Klein4BWeight,
        Flux2KleinBase4BWeight,
        Flux2Klein9BWeight,
        Flux2KleinBase9BWeight,
        AnimaWeight,
        Anima2Weight,
        ErnieImageWeight,
        LongCatImageWeight
    )]
    BF16,
    // `TQ1_0` through `Q1_0` likewise belong to no subenum.
    TQ1_0,
    TQ2_0,
    MXFP4,
    NVFP4,
    Q1_0,
    #[subenum(SSD1BWeight(default), QwenImageWeight)]
    F8_E4M3,
}
240
/// Built-in model presets.
///
/// Each variant is resolved by one function in `preset_builder` (see
/// `Preset::try_configs_builder`), which produces the default config builders for that
/// model. Variants with a payload carry a weight type (a subenum of [`WeightType`]) that
/// is forwarded unchanged to that function.
///
/// The strum-derived `PresetDiscriminants` enum mirrors these variants without payloads.
/// It parses from strings ASCII case-insensitively.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, EnumDiscriminants)]
#[strum_discriminants(derive(EnumString, VariantNames), strum(ascii_case_insensitive))]
pub enum Preset {
    StableDiffusion1_4,
    StableDiffusion1_5,
    StableDiffusion2_1,
    StableDiffusion3Medium,
    StableDiffusion3_5Medium,
    StableDiffusion3_5Large,
    StableDiffusion3_5LargeTurbo,
    SDXLBase1_0,
    SDTurbo,
    SDXLTurbo1_0,
    // Both Flux 1 dev and schnell accept the same set of weight types.
    Flux1Dev(Flux1Weight),
    Flux1Schnell(Flux1Weight),
    Flux1Mini(Flux1MiniWeight),
    JuggernautXL11,
    Chroma(ChromaWeight),
    NitroSDRealism(NitroSDRealismWeight),
    NitroSDVibrant(NitroSDVibrantWeight),
    DiffInstructStar(DiffInstructStarWeight),
    ChromaRadiance(ChromaRadianceWeight),
    SSD1B(SSD1BWeight),
    Flux2Dev(Flux2Weight),
    ZImageTurbo(ZImageTurboWeight),
    QwenImage(QwenImageWeight),
    OvisImage(OvisImageWeight),
    DreamShaperXL2_1Turbo,
    TwinFlowZImageTurboExp(TwinFlowZImageTurboExpWeight),
    SDXS512DreamShaper(SDXS512DreamShaperWeight),
    Flux2Klein4B(Flux2Klein4BWeight),
    Flux2KleinBase4B(Flux2KleinBase4BWeight),
    Flux2Klein9B(Flux2Klein9BWeight),
    Flux2KleinBase9B(Flux2KleinBase9BWeight),
    SegmindVega,
    Anima(AnimaWeight),
    Anima2(Anima2Weight),
    ErnieImage(ErnieImageWeight),
    // Shares `ErnieImageWeight` with `ErnieImage`; resolved by `ernie_image_turbo`.
    ErnieImageTurbo(ErnieImageWeight),
    HiDreamO1ImageDev,
    HiDreamO1Image,
    LongCatImage(LongCatImageWeight),
    Lens,
    LensTurbo,
}
349
350impl Preset {
351 fn try_configs_builder(self) -> Result<(ConfigBuilder, ModelConfigBuilder), ApiError> {
352 match self {
353 Preset::StableDiffusion1_4 => stable_diffusion_1_4(),
354 Preset::StableDiffusion1_5 => stable_diffusion_1_5(),
355 Preset::StableDiffusion2_1 => stable_diffusion_2_1(),
356 Preset::StableDiffusion3Medium => stable_diffusion_3_medium(),
357 Preset::SDXLBase1_0 => sdxl_base_1_0(),
358 Preset::Flux1Dev(sd_type_t) => flux_1_dev(sd_type_t),
359 Preset::Flux1Schnell(sd_type_t) => flux_1_schnell(sd_type_t),
360 Preset::SDTurbo => sd_turbo(),
361 Preset::SDXLTurbo1_0 => sdxl_turbo_1_0(),
362 Preset::StableDiffusion3_5Large => stable_diffusion_3_5_large(),
363 Preset::StableDiffusion3_5Medium => stable_diffusion_3_5_medium(),
364 Preset::StableDiffusion3_5LargeTurbo => stable_diffusion_3_5_large_turbo(),
365 Preset::JuggernautXL11 => juggernaut_xl_11(),
366 Preset::Flux1Mini(sd_type_t) => flux_1_mini(sd_type_t),
367 Preset::Chroma(sd_type_t) => chroma(sd_type_t),
368 Preset::NitroSDRealism(sd_type_t) => nitro_sd_realism(sd_type_t),
369 Preset::NitroSDVibrant(sd_type_t) => nitro_sd_vibrant(sd_type_t),
370 Preset::DiffInstructStar(sd_type_t) => diff_instruct_star(sd_type_t),
371 Preset::ChromaRadiance(sd_type_t) => chroma_radiance(sd_type_t),
372 Preset::SSD1B(sd_type_t) => ssd_1b(sd_type_t),
373 Preset::Flux2Dev(sd_type_t) => flux_2_dev(sd_type_t),
374 Preset::ZImageTurbo(sd_type_t) => z_image_turbo(sd_type_t),
375 Preset::QwenImage(sd_type_t) => qwen_image(sd_type_t),
376 Preset::OvisImage(sd_type_t) => ovis_image(sd_type_t),
377 Preset::DreamShaperXL2_1Turbo => dream_shaper_xl_2_1_turbo(),
378 Preset::TwinFlowZImageTurboExp(sd_type_t) => twinflow_z_image_turbo(sd_type_t),
379 Preset::SDXS512DreamShaper(sd_type_t) => sdxs512_dream_shaper(sd_type_t),
380 Preset::Flux2Klein4B(sd_type_t) => flux_2_klein_4b(sd_type_t),
381 Preset::Flux2KleinBase4B(sd_type_t) => flux_2_klein_base_4b(sd_type_t),
382 Preset::Flux2Klein9B(sd_type_t) => flux_2_klein_9b(sd_type_t),
383 Preset::Flux2KleinBase9B(sd_type_t) => flux_2_klein_base_9b(sd_type_t),
384 Preset::SegmindVega => segmind_vega(),
385 Preset::Anima(sd_type_t) => anima(sd_type_t),
386 Preset::Anima2(sd_type_t) => anima2(sd_type_t),
387 Preset::ErnieImage(sd_type_t) => ernie_image(sd_type_t),
388 Preset::ErnieImageTurbo(sd_type_t) => ernie_image_turbo(sd_type_t),
389 Preset::HiDreamO1ImageDev => hi_dream_o1_image_dev(),
390 Preset::HiDreamO1Image => hi_dream_o1_image(),
391 Preset::LongCatImage(sd_type_t) => long_cat_image(sd_type_t),
392 Preset::Lens => lens(),
393 Preset::LensTurbo => lens_turbo(),
394 }
395 }
396}
397
/// The `(ConfigBuilder, ModelConfigBuilder)` pair that every preset starts from and that
/// modifiers receive and return.
pub type ConfigsBuilder = (ConfigBuilder, ModelConfigBuilder);

/// The built counterpart of [`ConfigsBuilder`], returned by [`PresetBuilder::build`].
pub type Configs = (Config, ModelConfig);

/// One-shot hook that takes the preset's builders and returns them, possibly modified.
/// Returning an [`ApiError`] aborts the build. Hooks are registered with
/// `PresetBuilder::with_modifier`.
type ModifierFunction = dyn FnOnce(ConfigsBuilder) -> Result<ConfigsBuilder, ApiError>;
406
/// Inputs for building a preset: the prompt, the [`Preset`] to start from, and any
/// modifiers queued with `PresetBuilder::with_modifier`.
///
/// Construct it through the generated [`PresetBuilder`]. Its public `build` method turns
/// the collected inputs into the final [`Configs`].
#[derive(Builder)]
#[builder(
    name = "PresetBuilder",
    // Owned pattern: the boxed `FnOnce` modifiers are not `Clone`, which the other
    // derive_builder patterns would presumably require to build.
    pattern = "owned",
    setter(into),
    // The derived build fn stays private; `PresetBuilder::build` wraps it.
    build_fn(name = "internal_build", private, error = "ConfigBuilderError")
)]
pub struct PresetConfig {
    /// Prompt set on the config builder after all modifiers have run.
    prompt: String,
    /// Preset providing the initial config builders.
    preset: Preset,
    /// Modifiers applied in insertion order. There is no public setter; they are only
    /// populated through `with_modifier`.
    #[builder(private, default = "Vec::new()")]
    modifiers: Vec<Box<ModifierFunction>>,
}
421
422impl PresetBuilder {
423 pub fn with_modifier<F>(mut self, f: F) -> Self
425 where
426 F: FnOnce(ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> + 'static,
427 {
428 if self.modifiers.is_none() {
429 self.modifiers = Some(Vec::new());
430 }
431 self.modifiers.as_mut().unwrap().push(Box::new(f));
432 self
433 }
434
435 pub fn build(self) -> Result<Configs, ConfigBuilderError> {
436 let preset = self.internal_build()?;
437 let configs: ConfigsBuilder = preset
438 .try_into()
439 .map_err(|err: ApiError| ConfigBuilderError::ValidationError(err.to_string()))?;
440 let config = configs.0.build()?;
441 let config_model = configs.1.build()?;
442
443 Ok((config, config_model))
444 }
445}
446
447impl TryFrom<PresetConfig> for ConfigsBuilder {
448 type Error = ApiError;
449
450 fn try_from(value: PresetConfig) -> Result<Self, Self::Error> {
451 let mut configs_builder = value.preset.try_configs_builder()?;
452 for m in value.modifiers {
453 configs_builder = m(configs_builder)?;
454 }
455 configs_builder.0.prompt(value.prompt);
456 Ok(configs_builder)
457 }
458}
459
#[cfg(test)]
mod tests {
    use crate::{
        api::gen_img,
        preset::{
            Anima2Weight, AnimaWeight, ChromaRadianceWeight, ChromaWeight, DiffInstructStarWeight,
            ErnieImageWeight, Flux1MiniWeight, Flux1Weight, Flux2Klein4BWeight, Flux2Klein9BWeight,
            Flux2KleinBase4BWeight, Flux2KleinBase9BWeight, Flux2Weight, LongCatImageWeight,
            NitroSDRealismWeight, NitroSDVibrantWeight, OvisImageWeight, QwenImageWeight,
            SDXS512DreamShaperWeight, SSD1BWeight, TwinFlowZImageTurboExpWeight, WeightType,
            ZImageTurboWeight,
        },
        util::set_hf_token,
    };

    use super::{Preset, PresetBuilder, PresetDiscriminants};

    static PROMPT: &str = "a lovely dinosaur made by crochet";

    /// Builds `preset` with the shared test prompt and generates one image.
    fn run(preset: Preset) {
        let (config, mut model_config) = PresetBuilder::default()
            .preset(preset)
            .prompt(PROMPT)
            .build()
            .expect("preset should build");
        gen_img(&config, &mut model_config).expect("image generation should succeed");
    }

    /// Like `run`, but first installs the Hugging Face token needed by gated repositories.
    ///
    /// `include_str!` embeds `../token.txt` (relative to this source file) at compile time,
    /// so that file must exist for this test module to compile.
    fn run_with_token(preset: Preset) {
        set_hf_token(include_str!("../token.txt"));
        run(preset);
    }

    // ---- Fast checks (no network, no model downloads) ----

    #[test]
    fn weight_type_parses_ascii_case_insensitively() {
        assert!(matches!("q8_0".parse::<WeightType>(), Ok(WeightType::Q8_0)));
        assert!(matches!("Bf16".parse::<WeightType>(), Ok(WeightType::BF16)));
        assert!("not_a_weight".parse::<WeightType>().is_err());
    }

    #[test]
    fn preset_discriminant_parses_ascii_case_insensitively() {
        assert!(matches!(
            "sdturbo".parse::<PresetDiscriminants>(),
            Ok(PresetDiscriminants::SDTurbo)
        ));
        assert!("not_a_preset".parse::<PresetDiscriminants>().is_err());
    }

    /// Pins each subenum's `(default)` marker so accidental edits to the table are caught.
    #[test]
    fn subenum_defaults_match_annotations() {
        assert!(matches!(Flux1Weight::default(), Flux1Weight::Q2_K));
        assert!(matches!(Flux1MiniWeight::default(), Flux1MiniWeight::Q8_0));
        assert!(matches!(ChromaWeight::default(), ChromaWeight::Q4_0));
        assert!(matches!(NitroSDRealismWeight::default(), NitroSDRealismWeight::Q8_0));
        assert!(matches!(NitroSDVibrantWeight::default(), NitroSDVibrantWeight::Q8_0));
        assert!(matches!(DiffInstructStarWeight::default(), DiffInstructStarWeight::Q8_0));
        assert!(matches!(ChromaRadianceWeight::default(), ChromaRadianceWeight::Q8_0));
        assert!(matches!(SSD1BWeight::default(), SSD1BWeight::F8_E4M3));
        assert!(matches!(Flux2Weight::default(), Flux2Weight::Q2_K));
        assert!(matches!(ZImageTurboWeight::default(), ZImageTurboWeight::Q4_K));
        assert!(matches!(QwenImageWeight::default(), QwenImageWeight::Q2_K));
        assert!(matches!(OvisImageWeight::default(), OvisImageWeight::Q4_0));
        assert!(matches!(
            TwinFlowZImageTurboExpWeight::default(),
            TwinFlowZImageTurboExpWeight::Q4_0
        ));
        assert!(matches!(Flux2Klein4BWeight::default(), Flux2Klein4BWeight::Q8_0));
        assert!(matches!(Flux2KleinBase4BWeight::default(), Flux2KleinBase4BWeight::Q8_0));
        assert!(matches!(Flux2Klein9BWeight::default(), Flux2Klein9BWeight::Q4_0));
        assert!(matches!(Flux2KleinBase9BWeight::default(), Flux2KleinBase9BWeight::Q4_0));
        assert!(matches!(AnimaWeight::default(), AnimaWeight::Q8_0));
        assert!(matches!(Anima2Weight::default(), Anima2Weight::Q8_0));
        assert!(matches!(
            SDXS512DreamShaperWeight::default(),
            SDXS512DreamShaperWeight::F16
        ));
        assert!(matches!(ErnieImageWeight::default(), ErnieImageWeight::Q4_0));
        assert!(matches!(LongCatImageWeight::default(), LongCatImageWeight::Q4_0));
    }

    // ---- End-to-end generation (ignored: downloads models and runs inference) ----

    #[ignore]
    #[test]
    fn test_stable_diffusion_1_4() {
        run(Preset::StableDiffusion1_4);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_1_5() {
        run(Preset::StableDiffusion1_5);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_2_1() {
        run(Preset::StableDiffusion2_1);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_medium_fp16() {
        run_with_token(Preset::StableDiffusion3Medium);
    }

    #[ignore]
    #[test]
    fn test_sdxl_base_1_0() {
        run(Preset::SDXLBase1_0);
    }

    #[ignore]
    #[test]
    fn test_flux_1_dev() {
        run_with_token(Preset::Flux1Dev(Flux1Weight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_flux_1_schnell() {
        run_with_token(Preset::Flux1Schnell(Flux1Weight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_sd_turbo() {
        run(Preset::SDTurbo);
    }

    #[ignore]
    #[test]
    fn test_sdxl_turbo_1_0_fp16() {
        run(Preset::SDXLTurbo1_0);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_5_medium_fp16() {
        run_with_token(Preset::StableDiffusion3_5Medium);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_5_large_fp16() {
        run_with_token(Preset::StableDiffusion3_5Large);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_5_large_turbo_fp16() {
        run_with_token(Preset::StableDiffusion3_5LargeTurbo);
    }

    #[ignore]
    #[test]
    fn test_juggernaut_xl_11() {
        run_with_token(Preset::JuggernautXL11);
    }

    #[ignore]
    #[test]
    fn test_flux_1_mini() {
        run_with_token(Preset::Flux1Mini(Flux1MiniWeight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_chroma() {
        run_with_token(Preset::Chroma(ChromaWeight::Q4_0));
    }

    #[ignore]
    #[test]
    fn test_nitro_sd_realism() {
        run(Preset::NitroSDRealism(NitroSDRealismWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_nitro_sd_vibrant() {
        run(Preset::NitroSDVibrant(NitroSDVibrantWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_diff_instruct_star() {
        run(Preset::DiffInstructStar(DiffInstructStarWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_chroma_radiance() {
        run(Preset::ChromaRadiance(ChromaRadianceWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_ssd_1b() {
        run(Preset::SSD1B(SSD1BWeight::F8_E4M3));
    }

    #[ignore]
    #[test]
    fn test_flux_2_dev() {
        run_with_token(Preset::Flux2Dev(Flux2Weight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_z_image_turbo() {
        run_with_token(Preset::ZImageTurbo(ZImageTurboWeight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_qwen_image() {
        run(Preset::QwenImage(QwenImageWeight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_ovis_image() {
        run_with_token(Preset::OvisImage(OvisImageWeight::Q4_0));
    }

    #[ignore]
    #[test]
    fn test_dreamshaper_xl_2_1_turbo() {
        run(Preset::DreamShaperXL2_1Turbo);
    }

    #[ignore]
    #[test]
    fn test_twinflow_z_image_turbo_exp() {
        run_with_token(Preset::TwinFlowZImageTurboExp(
            TwinFlowZImageTurboExpWeight::Q3_K,
        ));
    }

    #[ignore]
    #[test]
    fn test_sdxs512_dream_shaper() {
        run(Preset::SDXS512DreamShaper(SDXS512DreamShaperWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_flux_2_klein_4b() {
        run_with_token(Preset::Flux2Klein4B(Flux2Klein4BWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_flux_2_klein_base_4b() {
        run_with_token(Preset::Flux2KleinBase4B(Flux2KleinBase4BWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_flux_2_klein_9b() {
        run_with_token(Preset::Flux2Klein9B(Flux2Klein9BWeight::Q4_0));
    }

    #[ignore]
    #[test]
    fn test_flux_2_klein_base_9b() {
        run_with_token(Preset::Flux2KleinBase9B(Flux2KleinBase9BWeight::Q4_0));
    }

    #[ignore]
    #[test]
    fn test_segmind_vega() {
        run(Preset::SegmindVega);
    }

    #[ignore]
    #[test]
    fn test_anima() {
        run(Preset::Anima(AnimaWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_anima2() {
        run(Preset::Anima2(Anima2Weight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_ernie_image() {
        run(Preset::ErnieImage(ErnieImageWeight::Q4_0));
    }

    #[ignore]
    #[test]
    fn test_ernie_image_turbo() {
        run(Preset::ErnieImageTurbo(ErnieImageWeight::Q4_0));
    }

    // NOTE(review): switch to `run_with_token` if the HiDream repositories turn out to be gated.
    #[ignore]
    #[test]
    fn test_hi_dream_o1_image_dev() {
        run(Preset::HiDreamO1ImageDev);
    }

    #[ignore]
    #[test]
    fn test_hi_dream_o1_image() {
        run(Preset::HiDreamO1Image);
    }

    #[ignore]
    #[test]
    fn test_long_cat_image() {
        run(Preset::LongCatImage(LongCatImageWeight::Q4_0));
    }

    #[ignore]
    #[test]
    fn test_lens() {
        run(Preset::Lens);
    }

    #[ignore]
    #[test]
    fn test_lens_turbo() {
        run(Preset::LensTurbo);
    }
}