1use derive_builder::Builder;
2use hf_hub::api::sync::ApiError;
3use strum::{EnumDiscriminants, EnumString, VariantNames};
4use subenum::subenum;
5
6use crate::{
7 api::{Config, ConfigBuilder, ConfigBuilderError, ModelConfig, ModelConfigBuilder},
8 preset_builder::{
9 anima, chroma, chroma_radiance, diff_instruct_star, dream_shaper_xl_2_1_turbo, flux_1_dev,
10 flux_1_mini, flux_1_schnell, flux_2_dev, flux_2_klein_4b, flux_2_klein_9b,
11 flux_2_klein_base_4b, flux_2_klein_base_9b, juggernaut_xl_11, nitro_sd_realism,
12 nitro_sd_vibrant, ovis_image, qwen_image, sd_turbo, sdxl_base_1_0, sdxl_turbo_1_0,
13 sdxs512_dream_shaper, segmind_vega, ssd_1b, stable_diffusion_1_4, stable_diffusion_1_5,
14 stable_diffusion_2_1, stable_diffusion_3_5_large, stable_diffusion_3_5_large_turbo,
15 stable_diffusion_3_5_medium, stable_diffusion_3_medium, twinflow_z_image_turbo,
16 z_image_turbo,
17 },
18};
19
/// Weight (tensor storage) types that a model checkpoint can be requested in.
///
/// Parsing from a string is ASCII-case-insensitive (`strum::EnumString` with
/// `ascii_case_insensitive`), and `strum::VariantNames` exposes every variant name.
///
/// The `#[subenum(...)]` attribute generates one narrower enum per weight-selectable [`Preset`]
/// variant (e.g. `Flux1Weight`, `QwenImageWeight`). A `WeightType` variant belongs to every
/// subenum named in its own `#[subenum(...)]` list, and the entry tagged `(default)` becomes that
/// subenum's `Default` value. Variants with no `#[subenum(...)]` list are not offered by any
/// preset.
///
/// NOTE(review): the variant set appears to mirror ggml's tensor type list — confirm before
/// relying on a 1:1 correspondence.
#[non_exhaustive]
// Variant names such as `Q4_0` and `F8_E4M3` are intentionally not UpperCamelCase.
#[allow(non_camel_case_types)]
#[subenum(
    Flux1Weight(derive(Default)),
    Flux1MiniWeight(derive(Default)),
    ChromaWeight(derive(Default)),
    NitroSDRealismWeight(derive(Default)),
    NitroSDVibrantWeight(derive(Default)),
    DiffInstructStarWeight(derive(Default)),
    ChromaRadianceWeight(derive(Default)),
    SSD1BWeight(derive(Default)),
    Flux2Weight(derive(Default)),
    ZImageTurboWeight(derive(Default)),
    QwenImageWeight(derive(Default)),
    OvisImageWeight(derive(Default)),
    TwinFlowZImageTurboExpWeight(derive(Default)),
    Flux2Klein4BWeight(derive(Default)),
    Flux2KleinBase4BWeight(derive(Default)),
    Flux2Klein9BWeight(derive(Default)),
    Flux2KleinBase9BWeight(derive(Default)),
    AnimaWeight(derive(Default))
)]
#[derive(Debug, Clone, Copy, EnumString, VariantNames)]
#[strum(ascii_case_insensitive)]
pub enum WeightType {
    #[subenum(Flux1MiniWeight)]
    F32,
    #[subenum(
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        SSD1BWeight
    )]
    F16,
    #[subenum(
        Flux1Weight,
        ChromaWeight(default),
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        OvisImageWeight(default),
        TwinFlowZImageTurboExpWeight(default),
        Flux2Klein4BWeight,
        Flux2KleinBase4BWeight,
        Flux2Klein9BWeight(default),
        Flux2KleinBase9BWeight(default),
        AnimaWeight
    )]
    Q4_0,
    #[subenum(Flux2Weight, QwenImageWeight, AnimaWeight)]
    Q4_1,
    #[subenum(
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        TwinFlowZImageTurboExpWeight,
        AnimaWeight
    )]
    Q5_0,
    #[subenum(Flux2Weight, QwenImageWeight, AnimaWeight)]
    Q5_1,
    #[subenum(
        Flux1Weight,
        Flux1MiniWeight(default),
        ChromaWeight,
        NitroSDRealismWeight(default),
        NitroSDVibrantWeight(default),
        DiffInstructStarWeight(default),
        ChromaRadianceWeight(default),
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        OvisImageWeight,
        TwinFlowZImageTurboExpWeight,
        Flux2Klein4BWeight(default),
        Flux2KleinBase4BWeight(default),
        Flux2Klein9BWeight,
        AnimaWeight(default)
    )]
    Q8_0,
    // Not part of any preset subenum.
    Q8_1,
    #[subenum(
        Flux1Weight(default),
        Flux1MiniWeight,
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        Flux2Weight(default),
        ZImageTurboWeight,
        QwenImageWeight(default)
    )]
    Q2_K,
    #[subenum(
        Flux1Weight,
        Flux1MiniWeight,
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        ZImageTurboWeight,
        Flux2Weight,
        QwenImageWeight,
        TwinFlowZImageTurboExpWeight,
        AnimaWeight
    )]
    Q3_K,
    #[subenum(
        Flux1Weight,
        ZImageTurboWeight(default),
        Flux2Weight,
        QwenImageWeight,
        AnimaWeight
    )]
    Q4_K,
    #[subenum(Flux1MiniWeight, Flux2Weight, QwenImageWeight, AnimaWeight)]
    Q5_K,
    #[subenum(
        Flux1MiniWeight,
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        TwinFlowZImageTurboExpWeight,
        AnimaWeight
    )]
    Q6_K,
    // Q8_K through IQ1_M are not part of any preset subenum.
    Q8_K,
    IQ2_XXS,
    IQ2_XS,
    IQ3_XXS,
    IQ1_S,
    IQ4_NL,
    IQ3_S,
    IQ2_S,
    IQ4_XS,
    I8,
    I16,
    I32,
    I64,
    F64,
    IQ1_M,
    #[subenum(
        Flux1MiniWeight,
        ChromaWeight,
        ChromaRadianceWeight,
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        OvisImageWeight,
        TwinFlowZImageTurboExpWeight,
        Flux2Klein4BWeight,
        Flux2KleinBase4BWeight,
        Flux2Klein9BWeight,
        Flux2KleinBase9BWeight,
        AnimaWeight
    )]
    BF16,
    // TQ1_0, TQ2_0 and MXFP4 are not part of any preset subenum.
    TQ1_0,
    TQ2_0,
    MXFP4,
    #[subenum(SSD1BWeight(default), QwenImageWeight)]
    F8_E4M3,
}
191
/// A ready-made model and generation setup.
///
/// Each variant is resolved by `Preset::try_configs_builder` into one function from
/// `preset_builder`. Variants carrying a `*Weight` payload let the caller choose which weight
/// type of that model is used; each payload type derives `Default`, and the default noted on
/// each variant below is the subenum's `(default)` entry in [`WeightType`].
///
/// `strum::EnumDiscriminants` additionally generates a fieldless companion enum (strum names it
/// `PresetDiscriminants` by default) deriving case-insensitive `EnumString` and `VariantNames`,
/// e.g. for parsing a preset name from text.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, EnumDiscriminants)]
#[strum_discriminants(derive(EnumString, VariantNames), strum(ascii_case_insensitive))]
pub enum Preset {
    /// Built by `preset_builder::stable_diffusion_1_4`.
    StableDiffusion1_4,
    /// Built by `preset_builder::stable_diffusion_1_5`.
    StableDiffusion1_5,
    /// Built by `preset_builder::stable_diffusion_2_1`.
    StableDiffusion2_1,
    /// Built by `preset_builder::stable_diffusion_3_medium`.
    StableDiffusion3Medium,
    /// Built by `preset_builder::stable_diffusion_3_5_medium`.
    StableDiffusion3_5Medium,
    /// Built by `preset_builder::stable_diffusion_3_5_large`.
    StableDiffusion3_5Large,
    /// Built by `preset_builder::stable_diffusion_3_5_large_turbo`.
    StableDiffusion3_5LargeTurbo,
    /// Built by `preset_builder::sdxl_base_1_0`.
    SDXLBase1_0,
    /// Built by `preset_builder::sd_turbo`.
    SDTurbo,
    /// Built by `preset_builder::sdxl_turbo_1_0`.
    SDXLTurbo1_0,
    /// Built by `preset_builder::flux_1_dev`; default weight `Q2_K`.
    Flux1Dev(Flux1Weight),
    /// Built by `preset_builder::flux_1_schnell`; default weight `Q2_K`.
    Flux1Schnell(Flux1Weight),
    /// Built by `preset_builder::flux_1_mini`; default weight `Q8_0`.
    Flux1Mini(Flux1MiniWeight),
    /// Built by `preset_builder::juggernaut_xl_11`.
    JuggernautXL11,
    /// Built by `preset_builder::chroma`; default weight `Q4_0`.
    Chroma(ChromaWeight),
    /// Built by `preset_builder::nitro_sd_realism`; default weight `Q8_0`.
    NitroSDRealism(NitroSDRealismWeight),
    /// Built by `preset_builder::nitro_sd_vibrant`; default weight `Q8_0`.
    NitroSDVibrant(NitroSDVibrantWeight),
    /// Built by `preset_builder::diff_instruct_star`; default weight `Q8_0`.
    DiffInstructStar(DiffInstructStarWeight),
    /// Built by `preset_builder::chroma_radiance`; default weight `Q8_0`.
    ChromaRadiance(ChromaRadianceWeight),
    /// Built by `preset_builder::ssd_1b`; default weight `F8_E4M3`.
    SSD1B(SSD1BWeight),
    /// Built by `preset_builder::flux_2_dev`; default weight `Q2_K`.
    Flux2Dev(Flux2Weight),
    /// Built by `preset_builder::z_image_turbo`; default weight `Q4_K`.
    ZImageTurbo(ZImageTurboWeight),
    /// Built by `preset_builder::qwen_image`; default weight `Q2_K`.
    QwenImage(QwenImageWeight),
    /// Built by `preset_builder::ovis_image`; default weight `Q4_0`.
    OvisImage(OvisImageWeight),
    /// Built by `preset_builder::dream_shaper_xl_2_1_turbo`.
    DreamShaperXL2_1Turbo,
    /// Built by `preset_builder::twinflow_z_image_turbo`; default weight `Q4_0`.
    TwinFlowZImageTurboExp(TwinFlowZImageTurboExpWeight),
    /// Built by `preset_builder::sdxs512_dream_shaper`.
    SDXS512DreamShaper,
    /// Built by `preset_builder::flux_2_klein_4b`; default weight `Q8_0`.
    Flux2Klein4B(Flux2Klein4BWeight),
    /// Built by `preset_builder::flux_2_klein_base_4b`; default weight `Q8_0`.
    Flux2KleinBase4B(Flux2KleinBase4BWeight),
    /// Built by `preset_builder::flux_2_klein_9b`; default weight `Q4_0`.
    Flux2Klein9B(Flux2Klein9BWeight),
    /// Built by `preset_builder::flux_2_klein_base_9b`; default weight `Q4_0`.
    Flux2KleinBase9B(Flux2KleinBase9BWeight),
    /// Built by `preset_builder::segmind_vega`.
    SegmindVega,
    /// Built by `preset_builder::anima`; default weight `Q8_0`.
    Anima(AnimaWeight),
}
281
282impl Preset {
283 fn try_configs_builder(self) -> Result<(ConfigBuilder, ModelConfigBuilder), ApiError> {
284 match self {
285 Preset::StableDiffusion1_4 => stable_diffusion_1_4(),
286 Preset::StableDiffusion1_5 => stable_diffusion_1_5(),
287 Preset::StableDiffusion2_1 => stable_diffusion_2_1(),
288 Preset::StableDiffusion3Medium => stable_diffusion_3_medium(),
289 Preset::SDXLBase1_0 => sdxl_base_1_0(),
290 Preset::Flux1Dev(sd_type_t) => flux_1_dev(sd_type_t),
291 Preset::Flux1Schnell(sd_type_t) => flux_1_schnell(sd_type_t),
292 Preset::SDTurbo => sd_turbo(),
293 Preset::SDXLTurbo1_0 => sdxl_turbo_1_0(),
294 Preset::StableDiffusion3_5Large => stable_diffusion_3_5_large(),
295 Preset::StableDiffusion3_5Medium => stable_diffusion_3_5_medium(),
296 Preset::StableDiffusion3_5LargeTurbo => stable_diffusion_3_5_large_turbo(),
297 Preset::JuggernautXL11 => juggernaut_xl_11(),
298 Preset::Flux1Mini(sd_type_t) => flux_1_mini(sd_type_t),
299 Preset::Chroma(sd_type_t) => chroma(sd_type_t),
300 Preset::NitroSDRealism(sd_type_t) => nitro_sd_realism(sd_type_t),
301 Preset::NitroSDVibrant(sd_type_t) => nitro_sd_vibrant(sd_type_t),
302 Preset::DiffInstructStar(sd_type_t) => diff_instruct_star(sd_type_t),
303 Preset::ChromaRadiance(sd_type_t) => chroma_radiance(sd_type_t),
304 Preset::SSD1B(sd_type_t) => ssd_1b(sd_type_t),
305 Preset::Flux2Dev(sd_type_t) => flux_2_dev(sd_type_t),
306 Preset::ZImageTurbo(sd_type_t) => z_image_turbo(sd_type_t),
307 Preset::QwenImage(sd_type_t) => qwen_image(sd_type_t),
308 Preset::OvisImage(sd_type_t) => ovis_image(sd_type_t),
309 Preset::DreamShaperXL2_1Turbo => dream_shaper_xl_2_1_turbo(),
310 Preset::TwinFlowZImageTurboExp(sd_type_t) => twinflow_z_image_turbo(sd_type_t),
311 Preset::SDXS512DreamShaper => sdxs512_dream_shaper(),
312 Preset::Flux2Klein4B(sd_type_t) => flux_2_klein_4b(sd_type_t),
313 Preset::Flux2KleinBase4B(sd_type_t) => flux_2_klein_base_4b(sd_type_t),
314 Preset::Flux2Klein9B(sd_type_t) => flux_2_klein_9b(sd_type_t),
315 Preset::Flux2KleinBase9B(sd_type_t) => flux_2_klein_base_9b(sd_type_t),
316 Preset::SegmindVega => segmind_vega(),
317 Preset::Anima(sd_type_t) => anima(sd_type_t),
318 }
319 }
320}
321
322pub type ConfigsBuilder = (ConfigBuilder, ModelConfigBuilder);
324
325pub type Configs = (Config, ModelConfig);
327
328type ModifierFunction = dyn FnOnce(ConfigsBuilder) -> Result<ConfigsBuilder, ApiError>;
330
/// Inputs for turning a [`Preset`] into a ready-to-use configuration pair.
///
/// Constructed through the derived `PresetBuilder` (owned pattern, `Into`-accepting setters).
/// The derived build step is private (`internal_build`); callers use `PresetBuilder::build`,
/// which returns the finished [`Configs`].
#[derive(Builder)]
#[builder(
    name = "PresetBuilder",
    pattern = "owned",
    setter(into),
    build_fn(name = "internal_build", private, error = "ConfigBuilderError")
)]
pub struct PresetConfig {
    /// Prompt applied to the generation config after all modifiers have run.
    prompt: String,
    /// Preset whose `preset_builder` function supplies the base builders.
    preset: Preset,
    /// Callbacks applied, in registration order, to the preset's builders. The generated setter
    /// is private; callers register callbacks with `PresetBuilder::with_modifier`. Empty when
    /// none are registered.
    #[builder(private, default = "Vec::new()")]
    modifiers: Vec<Box<ModifierFunction>>,
}
345
346impl PresetBuilder {
347 pub fn with_modifier<F>(mut self, f: F) -> Self
349 where
350 F: FnOnce(ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> + 'static,
351 {
352 if self.modifiers.is_none() {
353 self.modifiers = Some(Vec::new());
354 }
355 self.modifiers.as_mut().unwrap().push(Box::new(f));
356 self
357 }
358
359 pub fn build(self) -> Result<Configs, ConfigBuilderError> {
360 let preset = self.internal_build()?;
361 let configs: ConfigsBuilder = preset
362 .try_into()
363 .map_err(|err: ApiError| ConfigBuilderError::ValidationError(err.to_string()))?;
364 let config = configs.0.build()?;
365 let config_model = configs.1.build()?;
366
367 Ok((config, config_model))
368 }
369}
370
371impl TryFrom<PresetConfig> for ConfigsBuilder {
372 type Error = ApiError;
373
374 fn try_from(value: PresetConfig) -> Result<Self, Self::Error> {
375 let mut configs_builder = value.preset.try_configs_builder()?;
376 for m in value.modifiers {
377 configs_builder = m(configs_builder)?;
378 }
379 configs_builder.0.prompt(value.prompt);
380 Ok(configs_builder)
381 }
382}
383
#[cfg(test)]
mod tests {
    use crate::{
        api::gen_img,
        preset::{
            AnimaWeight, ChromaRadianceWeight, ChromaWeight, DiffInstructStarWeight,
            Flux1MiniWeight, Flux1Weight, Flux2Klein4BWeight, Flux2Klein9BWeight,
            Flux2KleinBase4BWeight, Flux2KleinBase9BWeight, Flux2Weight, NitroSDRealismWeight,
            NitroSDVibrantWeight, OvisImageWeight, QwenImageWeight, SSD1BWeight,
            TwinFlowZImageTurboExpWeight, ZImageTurboWeight,
        },
        util::set_hf_token,
    };

    use super::{Preset, PresetBuilder};
    static PROMPT: &str = "a lovely dinosaur made by crochet";

    /// Builds `preset` with the shared test prompt and runs `gen_img`, panicking on any error.
    fn run(preset: Preset) {
        let (config, mut model_config) = PresetBuilder::default()
            .preset(preset)
            .prompt(PROMPT)
            .build()
            .unwrap();
        gen_img(&config, &mut model_config).unwrap();
    }

    /// Like `run`, but first installs the Hugging Face token embedded from `../token.txt`
    /// (path relative to this source file, resolved at compile time).
    fn run_with_token(preset: Preset) {
        set_hf_token(include_str!("../token.txt"));
        run(preset);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_1_4() {
        run(Preset::StableDiffusion1_4);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_1_5() {
        run(Preset::StableDiffusion1_5);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_2_1() {
        run(Preset::StableDiffusion2_1);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_medium_fp16() {
        run_with_token(Preset::StableDiffusion3Medium);
    }

    #[ignore]
    #[test]
    fn test_sdxl_base_1_0() {
        run(Preset::SDXLBase1_0);
    }

    #[ignore]
    #[test]
    fn test_flux_1_dev() {
        run_with_token(Preset::Flux1Dev(Flux1Weight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_flux_1_schnell() {
        run_with_token(Preset::Flux1Schnell(Flux1Weight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_sd_turbo() {
        run(Preset::SDTurbo);
    }

    #[ignore]
    #[test]
    fn test_sdxl_turbo_1_0_fp16() {
        run(Preset::SDXLTurbo1_0);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_5_medium_fp16() {
        run_with_token(Preset::StableDiffusion3_5Medium);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_5_large_fp16() {
        run_with_token(Preset::StableDiffusion3_5Large);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_5_large_turbo_fp16() {
        run_with_token(Preset::StableDiffusion3_5LargeTurbo);
    }

    #[ignore]
    #[test]
    fn test_juggernaut_xl_11() {
        run_with_token(Preset::JuggernautXL11);
    }

    #[ignore]
    #[test]
    fn test_flux_1_mini() {
        run_with_token(Preset::Flux1Mini(Flux1MiniWeight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_chroma() {
        run_with_token(Preset::Chroma(ChromaWeight::Q4_0));
    }

    #[ignore]
    #[test]
    fn test_nitro_sd_realism() {
        run(Preset::NitroSDRealism(NitroSDRealismWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_nitro_sd_vibrant() {
        run(Preset::NitroSDVibrant(NitroSDVibrantWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_diff_instruct_star() {
        run(Preset::DiffInstructStar(DiffInstructStarWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_chroma_radiance() {
        run(Preset::ChromaRadiance(ChromaRadianceWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_ssd_1b() {
        run(Preset::SSD1B(SSD1BWeight::F8_E4M3));
    }

    #[ignore]
    #[test]
    fn test_flux_2_dev() {
        run_with_token(Preset::Flux2Dev(Flux2Weight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_z_image_turbo() {
        run_with_token(Preset::ZImageTurbo(ZImageTurboWeight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_qwen_image() {
        run(Preset::QwenImage(QwenImageWeight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_ovis_image() {
        run_with_token(Preset::OvisImage(OvisImageWeight::Q4_0));
    }

    #[ignore]
    #[test]
    fn test_dreamshaper_xl_2_1_turbo() {
        run(Preset::DreamShaperXL2_1Turbo);
    }

    #[ignore]
    #[test]
    fn test_twinflow_z_image_turbo_exp() {
        run_with_token(Preset::TwinFlowZImageTurboExp(
            TwinFlowZImageTurboExpWeight::Q3_K,
        ));
    }

    #[ignore]
    #[test]
    fn test_sdxs512_dream_shaper() {
        run(Preset::SDXS512DreamShaper);
    }

    #[ignore]
    #[test]
    fn test_flux_2_klein_4b() {
        run_with_token(Preset::Flux2Klein4B(Flux2Klein4BWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_flux_2_klein_base_4b() {
        run_with_token(Preset::Flux2KleinBase4B(Flux2KleinBase4BWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_flux_2_klein_9b() {
        run_with_token(Preset::Flux2Klein9B(Flux2Klein9BWeight::Q4_0));
    }

    #[ignore]
    #[test]
    fn test_flux_2_klein_base_9b() {
        run_with_token(Preset::Flux2KleinBase9B(Flux2KleinBase9BWeight::Q4_0));
    }

    #[ignore]
    #[test]
    fn test_segmind_vega() {
        run(Preset::SegmindVega);
    }

    #[ignore]
    #[test]
    fn test_anima() {
        run(Preset::Anima(AnimaWeight::Q8_0));
    }
}