1use derive_builder::Builder;
2use hf_hub::api::sync::ApiError;
3use strum::{EnumDiscriminants, EnumString, VariantNames};
4use subenum::subenum;
5
6use crate::{
7 api::{Config, ConfigBuilder, ConfigBuilderError, ModelConfig, ModelConfigBuilder},
8 preset_builder::{
9 chroma, chroma_radiance, diff_instruct_star, dream_shaper_xl_2_1_turbo, flux_1_dev,
10 flux_1_mini, flux_1_schnell, flux_2_dev, flux_2_klein_4b, flux_2_klein_9b,
11 flux_2_klein_base_4b, flux_2_klein_base_9b, juggernaut_xl_11, nitro_sd_realism,
12 nitro_sd_vibrant, ovis_image, qwen_image, sd_turbo, sdxl_base_1_0, sdxl_turbo_1_0,
13 sdxs512_dream_shaper, segmind_vega, ssd_1b, stable_diffusion_1_4, stable_diffusion_1_5,
14 stable_diffusion_2_1, stable_diffusion_3_5_large, stable_diffusion_3_5_large_turbo,
15 stable_diffusion_3_5_medium, stable_diffusion_3_medium, twinflow_z_image_turbo,
16 z_image_turbo,
17 },
18};
19
/// Weight data types (float, integer and quantized formats) that presets can
/// be parameterized with.
///
/// Every variant can be parsed case-insensitively from a string through the
/// derived [`EnumString`] impl (e.g. `"q8_0".parse::<WeightType>()`), and all
/// variant names are listed by the derived [`VariantNames`] impl.
///
/// The `#[subenum(...)]` attribute generates one smaller enum per model
/// family (for example [`Flux1Weight`] or [`QwenImageWeight`]). Each one
/// contains only the variants tagged with that family's name. These sub-enums
/// are the payloads of the matching [`Preset`] variants, so a weight type that
/// is not listed for a model cannot be passed to that model's preset. A
/// `(default)` marker selects the variant returned by that sub-enum's derived
/// `Default` impl. Variants without any `#[subenum]` tag cannot be used with
/// any preset.
#[non_exhaustive]
// Variant names keep the underscores of the format names (e.g. `Q4_0`,
// `IQ2_XXS`), which would otherwise trigger the camel-case lint.
#[allow(non_camel_case_types)]
#[subenum(
    Flux1Weight(derive(Default)),
    Flux1MiniWeight(derive(Default)),
    ChromaWeight(derive(Default)),
    NitroSDRealismWeight(derive(Default)),
    NitroSDVibrantWeight(derive(Default)),
    DiffInstructStarWeight(derive(Default)),
    ChromaRadianceWeight(derive(Default)),
    SSD1BWeight(derive(Default)),
    Flux2Weight(derive(Default)),
    ZImageTurboWeight(derive(Default)),
    QwenImageWeight(derive(Default)),
    OvisImageWeight(derive(Default)),
    TwinFlowZImageTurboExpWeight(derive(Default)),
    Flux2Klein4BWeight(derive(Default)),
    Flux2KleinBase4BWeight(derive(Default)),
    Flux2Klein9BWeight(derive(Default)),
    Flux2KleinBase9BWeight(derive(Default))
)]
#[derive(Debug, Clone, Copy, EnumString, VariantNames)]
#[strum(ascii_case_insensitive)]
pub enum WeightType {
    #[subenum(Flux1MiniWeight)]
    F32,
    #[subenum(
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        SSD1BWeight
    )]
    F16,
    #[subenum(
        Flux1Weight,
        ChromaWeight(default),
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        OvisImageWeight(default),
        TwinFlowZImageTurboExpWeight(default),
        Flux2Klein4BWeight,
        Flux2KleinBase4BWeight,
        Flux2Klein9BWeight(default),
        Flux2KleinBase9BWeight(default)
    )]
    Q4_0,
    #[subenum(Flux2Weight, QwenImageWeight)]
    Q4_1,
    #[subenum(
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        TwinFlowZImageTurboExpWeight
    )]
    Q5_0,
    #[subenum(Flux2Weight, QwenImageWeight)]
    Q5_1,
    #[subenum(
        Flux1Weight,
        Flux1MiniWeight(default),
        ChromaWeight,
        NitroSDRealismWeight(default),
        NitroSDVibrantWeight(default),
        DiffInstructStarWeight(default),
        ChromaRadianceWeight(default),
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        OvisImageWeight,
        TwinFlowZImageTurboExpWeight,
        Flux2Klein4BWeight(default),
        Flux2KleinBase4BWeight(default),
        Flux2Klein9BWeight
    )]
    Q8_0,
    // Not offered by any preset (no `#[subenum]` tag).
    Q8_1,
    #[subenum(
        Flux1Weight(default),
        Flux1MiniWeight,
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        Flux2Weight(default),
        ZImageTurboWeight,
        QwenImageWeight(default)
    )]
    Q2_K,
    #[subenum(
        Flux1Weight,
        Flux1MiniWeight,
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        ZImageTurboWeight,
        Flux2Weight,
        QwenImageWeight,
        TwinFlowZImageTurboExpWeight
    )]
    Q3_K,
    #[subenum(Flux1Weight, ZImageTurboWeight(default), Flux2Weight, QwenImageWeight)]
    Q4_K,
    #[subenum(Flux1MiniWeight, Flux2Weight, QwenImageWeight)]
    Q5_K,
    #[subenum(
        Flux1MiniWeight,
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        TwinFlowZImageTurboExpWeight
    )]
    Q6_K,
    // Q8_K through IQ1_M: not offered by any preset (no `#[subenum]` tag),
    // but still parseable as a `WeightType`.
    Q8_K,
    IQ2_XXS,
    IQ2_XS,
    IQ3_XXS,
    IQ1_S,
    IQ4_NL,
    IQ3_S,
    IQ2_S,
    IQ4_XS,
    I8,
    I16,
    I32,
    I64,
    F64,
    IQ1_M,
    #[subenum(
        Flux1MiniWeight,
        ChromaWeight,
        ChromaRadianceWeight,
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        OvisImageWeight,
        TwinFlowZImageTurboExpWeight,
        Flux2Klein4BWeight,
        Flux2KleinBase4BWeight,
        Flux2Klein9BWeight,
        Flux2KleinBase9BWeight
    )]
    BF16,
    // TQ1_0, TQ2_0 and MXFP4: not offered by any preset (no `#[subenum]` tag).
    TQ1_0,
    TQ2_0,
    MXFP4,
    #[subenum(SSD1BWeight(default), QwenImageWeight)]
    F8_E4M3,
}
178
/// A ready-made model setup that [`PresetBuilder`] expands into a
/// ([`Config`], [`ModelConfig`]) pair.
///
/// Unit variants always use the same weights. Variants with a payload let the
/// caller choose the weight type. That payload is a model-specific subset of
/// [`WeightType`], and it is forwarded unchanged to the matching constructor
/// in `preset_builder`.
///
/// The derived `PresetDiscriminants` enum (generated by
/// [`EnumDiscriminants`]) has one payload-free variant per preset. It can be
/// parsed case-insensitively from a string and lists its names via
/// `VariantNames`.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, EnumDiscriminants)]
#[strum_discriminants(derive(EnumString, VariantNames), strum(ascii_case_insensitive))]
pub enum Preset {
    StableDiffusion1_4,
    StableDiffusion1_5,
    StableDiffusion2_1,
    StableDiffusion3Medium,
    StableDiffusion3_5Medium,
    StableDiffusion3_5Large,
    StableDiffusion3_5LargeTurbo,
    SDXLBase1_0,
    SDTurbo,
    SDXLTurbo1_0,
    Flux1Dev(Flux1Weight),
    Flux1Schnell(Flux1Weight),
    Flux1Mini(Flux1MiniWeight),
    JuggernautXL11,
    Chroma(ChromaWeight),
    NitroSDRealism(NitroSDRealismWeight),
    NitroSDVibrant(NitroSDVibrantWeight),
    DiffInstructStar(DiffInstructStarWeight),
    ChromaRadiance(ChromaRadianceWeight),
    SSD1B(SSD1BWeight),
    Flux2Dev(Flux2Weight),
    ZImageTurbo(ZImageTurboWeight),
    QwenImage(QwenImageWeight),
    OvisImage(OvisImageWeight),
    DreamShaperXL2_1Turbo,
    TwinFlowZImageTurboExp(TwinFlowZImageTurboExpWeight),
    SDXS512DreamShaper,
    Flux2Klein4B(Flux2Klein4BWeight),
    Flux2KleinBase4B(Flux2KleinBase4BWeight),
    Flux2Klein9B(Flux2Klein9BWeight),
    Flux2KleinBase9B(Flux2KleinBase9BWeight),
    SegmindVega,
}
266
267impl Preset {
268 fn try_configs_builder(self) -> Result<(ConfigBuilder, ModelConfigBuilder), ApiError> {
269 match self {
270 Preset::StableDiffusion1_4 => stable_diffusion_1_4(),
271 Preset::StableDiffusion1_5 => stable_diffusion_1_5(),
272 Preset::StableDiffusion2_1 => stable_diffusion_2_1(),
273 Preset::StableDiffusion3Medium => stable_diffusion_3_medium(),
274 Preset::SDXLBase1_0 => sdxl_base_1_0(),
275 Preset::Flux1Dev(sd_type_t) => flux_1_dev(sd_type_t),
276 Preset::Flux1Schnell(sd_type_t) => flux_1_schnell(sd_type_t),
277 Preset::SDTurbo => sd_turbo(),
278 Preset::SDXLTurbo1_0 => sdxl_turbo_1_0(),
279 Preset::StableDiffusion3_5Large => stable_diffusion_3_5_large(),
280 Preset::StableDiffusion3_5Medium => stable_diffusion_3_5_medium(),
281 Preset::StableDiffusion3_5LargeTurbo => stable_diffusion_3_5_large_turbo(),
282 Preset::JuggernautXL11 => juggernaut_xl_11(),
283 Preset::Flux1Mini(sd_type_t) => flux_1_mini(sd_type_t),
284 Preset::Chroma(sd_type_t) => chroma(sd_type_t),
285 Preset::NitroSDRealism(sd_type_t) => nitro_sd_realism(sd_type_t),
286 Preset::NitroSDVibrant(sd_type_t) => nitro_sd_vibrant(sd_type_t),
287 Preset::DiffInstructStar(sd_type_t) => diff_instruct_star(sd_type_t),
288 Preset::ChromaRadiance(sd_type_t) => chroma_radiance(sd_type_t),
289 Preset::SSD1B(sd_type_t) => ssd_1b(sd_type_t),
290 Preset::Flux2Dev(sd_type_t) => flux_2_dev(sd_type_t),
291 Preset::ZImageTurbo(sd_type_t) => z_image_turbo(sd_type_t),
292 Preset::QwenImage(sd_type_t) => qwen_image(sd_type_t),
293 Preset::OvisImage(sd_type_t) => ovis_image(sd_type_t),
294 Preset::DreamShaperXL2_1Turbo => dream_shaper_xl_2_1_turbo(),
295 Preset::TwinFlowZImageTurboExp(sd_type_t) => twinflow_z_image_turbo(sd_type_t),
296 Preset::SDXS512DreamShaper => sdxs512_dream_shaper(),
297 Preset::Flux2Klein4B(sd_type_t) => flux_2_klein_4b(sd_type_t),
298 Preset::Flux2KleinBase4B(sd_type_t) => flux_2_klein_base_4b(sd_type_t),
299 Preset::Flux2Klein9B(sd_type_t) => flux_2_klein_9b(sd_type_t),
300 Preset::Flux2KleinBase9B(sd_type_t) => flux_2_klein_base_9b(sd_type_t),
301 Preset::SegmindVega => segmind_vega(),
302 }
303 }
304}
305
/// Pair of not-yet-built builders for a preset: the [`ConfigBuilder`] (which
/// receives the prompt) and the [`ModelConfigBuilder`].
pub type ConfigsBuilder = (ConfigBuilder, ModelConfigBuilder);

/// Finished configuration pair returned by [`PresetBuilder::build`].
pub type Configs = (Config, ModelConfig);

/// One-shot hook that receives a preset's builders and returns them, possibly
/// modified. An `Err` aborts the build. Registered via
/// [`PresetBuilder::with_modifier`].
type ModifierFunction = dyn FnOnce(ConfigsBuilder) -> Result<ConfigsBuilder, ApiError>;
314
/// Inputs for turning a [`Preset`] into ready-to-use [`Configs`].
///
/// Construct it through the generated [`PresetBuilder`], which uses the owned
/// pattern and `Into` setters, then call [`PresetBuilder::build`]. The
/// generated `internal_build` is private, so callers always receive the fully
/// built configs rather than this intermediate struct. `prompt` and `preset`
/// have no default, so building fails with a [`ConfigBuilderError`] if either
/// is unset.
#[derive(Builder)]
#[builder(
    name = "PresetBuilder",
    pattern = "owned",
    setter(into),
    build_fn(name = "internal_build", private, error = "ConfigBuilderError")
)]
pub struct PresetConfig {
    /// Prompt set on the config builder after all modifiers have run.
    prompt: String,
    /// Preset that supplies the initial builders.
    preset: Preset,
    // The generated setter is private: this list is filled only through
    // `PresetBuilder::with_modifier`, and an untouched builder yields an empty list.
    #[builder(private, default = "Vec::new()")]
    modifiers: Vec<Box<ModifierFunction>>,
}
329
330impl PresetBuilder {
331 pub fn with_modifier<F>(mut self, f: F) -> Self
333 where
334 F: FnOnce(ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> + 'static,
335 {
336 if self.modifiers.is_none() {
337 self.modifiers = Some(Vec::new());
338 }
339 self.modifiers.as_mut().unwrap().push(Box::new(f));
340 self
341 }
342
343 pub fn build(self) -> Result<Configs, ConfigBuilderError> {
344 let preset = self.internal_build()?;
345 let configs: ConfigsBuilder = preset
346 .try_into()
347 .map_err(|err: ApiError| ConfigBuilderError::ValidationError(err.to_string()))?;
348 let config = configs.0.build()?;
349 let config_model = configs.1.build()?;
350
351 Ok((config, config_model))
352 }
353}
354
355impl TryFrom<PresetConfig> for ConfigsBuilder {
356 type Error = ApiError;
357
358 fn try_from(value: PresetConfig) -> Result<Self, Self::Error> {
359 let mut configs_builder = value.preset.try_configs_builder()?;
360 for m in value.modifiers {
361 configs_builder = m(configs_builder)?;
362 }
363 configs_builder.0.prompt(value.prompt);
364 Ok(configs_builder)
365 }
366}
367
#[cfg(test)]
mod tests {
    //! End-to-end checks that build each preset and run `gen_img` on it.
    //! Every test is `#[ignore]`d; run them explicitly with `cargo test -- --ignored`.

    use crate::{
        api::gen_img,
        preset::{
            ChromaRadianceWeight, ChromaWeight, DiffInstructStarWeight, Flux1MiniWeight,
            Flux1Weight, Flux2Klein4BWeight, Flux2Klein9BWeight, Flux2KleinBase4BWeight,
            Flux2KleinBase9BWeight, Flux2Weight, NitroSDRealismWeight, NitroSDVibrantWeight,
            OvisImageWeight, QwenImageWeight, SSD1BWeight, TwinFlowZImageTurboExpWeight,
            ZImageTurboWeight,
        },
        util::set_hf_token,
    };

    use super::{Preset, PresetBuilder};

    static PROMPT: &str = "a lovely dinosaur made by crochet";

    /// Builds `preset` with the shared prompt and generates an image, panicking on failure.
    fn run(preset: Preset) {
        let built = PresetBuilder::default()
            .preset(preset)
            .prompt(PROMPT)
            .build();
        let (config, mut model_config) = built.unwrap();
        gen_img(&config, &mut model_config).unwrap();
    }

    /// Sets the Hugging Face token, embedded at compile time from
    /// `../token.txt`, then delegates to [`run`].
    fn run_with_token(preset: Preset) {
        set_hf_token(include_str!("../token.txt"));
        run(preset);
    }

    #[test]
    #[ignore]
    fn test_stable_diffusion_1_4() {
        run(Preset::StableDiffusion1_4);
    }

    #[test]
    #[ignore]
    fn test_stable_diffusion_1_5() {
        run(Preset::StableDiffusion1_5);
    }

    #[test]
    #[ignore]
    fn test_stable_diffusion_2_1() {
        run(Preset::StableDiffusion2_1);
    }

    #[test]
    #[ignore]
    fn test_stable_diffusion_3_medium_fp16() {
        run_with_token(Preset::StableDiffusion3Medium);
    }

    #[test]
    #[ignore]
    fn test_sdxl_base_1_0() {
        run(Preset::SDXLBase1_0);
    }

    #[test]
    #[ignore]
    fn test_flux_1_dev() {
        run_with_token(Preset::Flux1Dev(Flux1Weight::Q2_K));
    }

    #[test]
    #[ignore]
    fn test_flux_1_schnell() {
        run_with_token(Preset::Flux1Schnell(Flux1Weight::Q2_K));
    }

    #[test]
    #[ignore]
    fn test_sd_turbo() {
        run(Preset::SDTurbo);
    }

    #[test]
    #[ignore]
    fn test_sdxl_turbo_1_0_fp16() {
        run(Preset::SDXLTurbo1_0);
    }

    #[test]
    #[ignore]
    fn test_stable_diffusion_3_5_medium_fp16() {
        run_with_token(Preset::StableDiffusion3_5Medium);
    }

    #[test]
    #[ignore]
    fn test_stable_diffusion_3_5_large_fp16() {
        run_with_token(Preset::StableDiffusion3_5Large);
    }

    #[test]
    #[ignore]
    fn test_stable_diffusion_3_5_large_turbo_fp16() {
        run_with_token(Preset::StableDiffusion3_5LargeTurbo);
    }

    #[test]
    #[ignore]
    fn test_juggernaut_xl_11() {
        run_with_token(Preset::JuggernautXL11);
    }

    #[test]
    #[ignore]
    fn test_flux_1_mini() {
        run_with_token(Preset::Flux1Mini(Flux1MiniWeight::Q2_K));
    }

    #[test]
    #[ignore]
    fn test_chroma() {
        run_with_token(Preset::Chroma(ChromaWeight::Q4_0));
    }

    #[test]
    #[ignore]
    fn test_nitro_sd_realism() {
        run(Preset::NitroSDRealism(NitroSDRealismWeight::Q8_0));
    }

    #[test]
    #[ignore]
    fn test_nitro_sd_vibrant() {
        run(Preset::NitroSDVibrant(NitroSDVibrantWeight::Q8_0));
    }

    #[test]
    #[ignore]
    fn test_diff_instruct_star() {
        run(Preset::DiffInstructStar(DiffInstructStarWeight::Q8_0));
    }

    #[test]
    #[ignore]
    fn test_chroma_radiance() {
        run(Preset::ChromaRadiance(ChromaRadianceWeight::Q8_0));
    }

    #[test]
    #[ignore]
    fn test_ssd_1b() {
        run(Preset::SSD1B(SSD1BWeight::F8_E4M3));
    }

    #[test]
    #[ignore]
    fn test_flux_2_dev() {
        run_with_token(Preset::Flux2Dev(Flux2Weight::Q2_K));
    }

    #[test]
    #[ignore]
    fn test_z_image_turbo() {
        run_with_token(Preset::ZImageTurbo(ZImageTurboWeight::Q2_K));
    }

    #[test]
    #[ignore]
    fn test_qwen_image() {
        run(Preset::QwenImage(QwenImageWeight::Q2_K));
    }

    #[test]
    #[ignore]
    fn test_ovis_image() {
        run_with_token(Preset::OvisImage(OvisImageWeight::Q4_0));
    }

    #[test]
    #[ignore]
    fn test_dreamshaper_xl_2_1_turbo() {
        run(Preset::DreamShaperXL2_1Turbo);
    }

    #[test]
    #[ignore]
    fn test_twinflow_z_image_turbo_exp() {
        run_with_token(Preset::TwinFlowZImageTurboExp(
            TwinFlowZImageTurboExpWeight::Q3_K,
        ));
    }

    #[test]
    #[ignore]
    fn test_sdxs512_dream_shaper() {
        run(Preset::SDXS512DreamShaper);
    }

    #[test]
    #[ignore]
    fn test_flux_2_klein_4b() {
        run_with_token(Preset::Flux2Klein4B(Flux2Klein4BWeight::Q8_0));
    }

    #[test]
    #[ignore]
    fn test_flux_2_klein_base_4b() {
        run_with_token(Preset::Flux2KleinBase4B(Flux2KleinBase4BWeight::Q8_0));
    }

    #[test]
    #[ignore]
    fn test_flux_2_klein_9b() {
        run_with_token(Preset::Flux2Klein9B(Flux2Klein9BWeight::Q4_0));
    }

    #[test]
    #[ignore]
    fn test_flux_2_klein_base_9b() {
        run_with_token(Preset::Flux2KleinBase9B(Flux2KleinBase9BWeight::Q4_0));
    }

    #[test]
    #[ignore]
    fn test_segmind_vega() {
        run(Preset::SegmindVega);
    }
}