1use derive_builder::Builder;
2use hf_hub::api::sync::ApiError;
3use strum::{EnumDiscriminants, EnumString, VariantNames};
4use subenum::subenum;
5
6use crate::{
7 api::{Config, ConfigBuilder, ConfigBuilderError, ModelConfig, ModelConfigBuilder},
8 preset_builder::{
9 chroma, chroma_radiance, diff_instruct_star, dream_shaper_xl_2_1_turbo, flux_1_dev,
10 flux_1_mini, flux_1_schnell, flux_2_dev, flux_2_klein_4b, flux_2_klein_9b,
11 flux_2_klein_base_4b, flux_2_klein_base_9b, juggernaut_xl_11, nitro_sd_realism,
12 nitro_sd_vibrant, ovis_image, qwen_image, sd_turbo, sdxl_base_1_0, sdxl_turbo_1_0,
13 sdxs512_dream_shaper, ssd_1b, stable_diffusion_1_4, stable_diffusion_1_5,
14 stable_diffusion_2_1, stable_diffusion_3_5_large, stable_diffusion_3_5_large_turbo,
15 stable_diffusion_3_5_medium, stable_diffusion_3_medium, twinflow_z_image_turbo,
16 z_image_turbo,
17 },
18};
19
/// Weight (tensor data) types in which a model's files can be requested.
///
/// The variant names appear to mirror GGML tensor types (e.g. `Q4_0`, `Q8_K`, `BF16`).
/// NOTE(review): confirm against how `preset_builder` maps these to files / `sd_type_t`.
///
/// Parsing: `EnumString` is derived with `ascii_case_insensitive`, so `"q4_0"` and `"Q4_0"`
/// both parse to [`WeightType::Q4_0`]. `VariantNames` exposes the list of variant names.
///
/// Per-model subsets: `#[subenum(...)]` generates one restricted enum per model family
/// (`Flux1Weight`, `ChromaWeight`, ...), each deriving `Default`. A variant belongs to a
/// sub-enum only if that sub-enum is listed in the variant's own `#[subenum(...)]` attribute.
/// The entry marked `(default)` is that sub-enum's `Default` value. Variants without any
/// `#[subenum]` attribute belong to no sub-enum. Because every weight-carrying [`Preset`]
/// variant takes a sub-enum, no preset can currently select those variants.
#[non_exhaustive]
// Variant names such as `Q4_0` or `F8_E4M3` contain underscores, which the
// `non_camel_case_types` lint would otherwise reject.
#[allow(non_camel_case_types)]
#[subenum(
    Flux1Weight(derive(Default)),
    Flux1MiniWeight(derive(Default)),
    ChromaWeight(derive(Default)),
    NitroSDRealismWeight(derive(Default)),
    NitroSDVibrantWeight(derive(Default)),
    DiffInstructStarWeight(derive(Default)),
    ChromaRadianceWeight(derive(Default)),
    SSD1BWeight(derive(Default)),
    Flux2Weight(derive(Default)),
    ZImageTurboWeight(derive(Default)),
    QwenImageWeight(derive(Default)),
    OvisImageWeight(derive(Default)),
    TwinFlowZImageTurboExpWeight(derive(Default)),
    Flux2Klein4BWeight(derive(Default)),
    Flux2KleinBase4BWeight(derive(Default)),
    Flux2Klein9BWeight(derive(Default)),
    Flux2KleinBase9BWeight(derive(Default))
)]
#[derive(Debug, Clone, Copy, EnumString, VariantNames)]
#[strum(ascii_case_insensitive)]
pub enum WeightType {
    #[subenum(Flux1MiniWeight)]
    F32,
    #[subenum(
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        SSD1BWeight
    )]
    F16,
    #[subenum(
        Flux1Weight,
        ChromaWeight(default),
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        OvisImageWeight(default),
        TwinFlowZImageTurboExpWeight(default),
        Flux2Klein4BWeight,
        Flux2KleinBase4BWeight,
        Flux2Klein9BWeight(default),
        Flux2KleinBase9BWeight(default)
    )]
    Q4_0,
    #[subenum(Flux2Weight, QwenImageWeight)]
    Q4_1,
    #[subenum(
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        TwinFlowZImageTurboExpWeight
    )]
    Q5_0,
    #[subenum(Flux2Weight, QwenImageWeight)]
    Q5_1,
    // NOTE(review): `Flux2Klein9BWeight` lists Q8_0 but `Flux2KleinBase9BWeight` does not;
    // confirm this reflects which files actually exist upstream.
    #[subenum(
        Flux1Weight,
        Flux1MiniWeight(default),
        ChromaWeight,
        NitroSDRealismWeight(default),
        NitroSDVibrantWeight(default),
        DiffInstructStarWeight(default),
        ChromaRadianceWeight(default),
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        OvisImageWeight,
        TwinFlowZImageTurboExpWeight,
        Flux2Klein4BWeight(default),
        Flux2KleinBase4BWeight(default),
        Flux2Klein9BWeight
    )]
    Q8_0,
    // Not part of any per-model sub-enum.
    Q8_1,
    #[subenum(
        Flux1Weight(default),
        Flux1MiniWeight,
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        Flux2Weight(default),
        ZImageTurboWeight,
        QwenImageWeight(default)
    )]
    Q2_K,
    #[subenum(
        Flux1Weight,
        Flux1MiniWeight,
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        ZImageTurboWeight,
        Flux2Weight,
        QwenImageWeight,
        TwinFlowZImageTurboExpWeight
    )]
    Q3_K,
    #[subenum(Flux1Weight, ZImageTurboWeight(default), Flux2Weight, QwenImageWeight)]
    Q4_K,
    #[subenum(Flux1MiniWeight, Flux2Weight, QwenImageWeight)]
    Q5_K,
    #[subenum(
        Flux1MiniWeight,
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        TwinFlowZImageTurboExpWeight
    )]
    Q6_K,
    // The variants from here through `IQ1_M` are not part of any per-model sub-enum.
    Q8_K,
    IQ2_XXS,
    IQ2_XS,
    IQ3_XXS,
    IQ1_S,
    IQ4_NL,
    IQ3_S,
    IQ2_S,
    IQ4_XS,
    I8,
    I16,
    I32,
    I64,
    F64,
    IQ1_M,
    #[subenum(
        Flux1MiniWeight,
        ChromaWeight,
        ChromaRadianceWeight,
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        OvisImageWeight,
        TwinFlowZImageTurboExpWeight,
        Flux2Klein4BWeight,
        Flux2KleinBase4BWeight,
        Flux2Klein9BWeight,
        Flux2KleinBase9BWeight
    )]
    BF16,
    // `TQ1_0`, `TQ2_0` and `MXFP4` are not part of any per-model sub-enum.
    TQ1_0,
    TQ2_0,
    MXFP4,
    #[subenum(SSD1BWeight(default), QwenImageWeight)]
    F8_E4M3,
}
178
/// A ready-made model configuration that [`PresetBuilder`] turns into a
/// `(Config, ModelConfig)` pair.
///
/// Unit variants always use the settings chosen by their `preset_builder` function. Variants
/// with a payload forward the chosen weight type to that function. The payload is a per-model
/// sub-enum of [`WeightType`], and its `Default` is the preset's default weight.
///
/// `#[derive(EnumDiscriminants)]` also generates a payload-free `PresetDiscriminants` enum.
/// It can be parsed case-insensitively from a string (`EnumString`) and lists its variant
/// names (`VariantNames`).
///
/// NOTE(review): the smoke tests call `set_hf_token` before several of these presets. Those
/// models may be gated on Hugging Face and require a token — confirm per preset.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, EnumDiscriminants)]
#[strum_discriminants(derive(EnumString, VariantNames), strum(ascii_case_insensitive))]
pub enum Preset {
    // Unit presets: weights are fixed by the corresponding `preset_builder` function.
    StableDiffusion1_4,
    StableDiffusion1_5,
    StableDiffusion2_1,
    StableDiffusion3Medium,
    StableDiffusion3_5Medium,
    StableDiffusion3_5Large,
    StableDiffusion3_5LargeTurbo,
    SDXLBase1_0,
    SDTurbo,
    SDXLTurbo1_0,
    // From here on, most presets carry the weight type to download.
    Flux1Dev(Flux1Weight),
    Flux1Schnell(Flux1Weight),
    Flux1Mini(Flux1MiniWeight),
    JuggernautXL11,
    Chroma(ChromaWeight),
    NitroSDRealism(NitroSDRealismWeight),
    NitroSDVibrant(NitroSDVibrantWeight),
    DiffInstructStar(DiffInstructStarWeight),
    ChromaRadiance(ChromaRadianceWeight),
    SSD1B(SSD1BWeight),
    Flux2Dev(Flux2Weight),
    ZImageTurbo(ZImageTurboWeight),
    QwenImage(QwenImageWeight),
    OvisImage(OvisImageWeight),
    DreamShaperXL2_1Turbo,
    TwinFlowZImageTurboExp(TwinFlowZImageTurboExpWeight),
    SDXS512DreamShaper,
    Flux2Klein4B(Flux2Klein4BWeight),
    Flux2KleinBase4B(Flux2KleinBase4BWeight),
    Flux2Klein9B(Flux2Klein9BWeight),
    Flux2KleinBase9B(Flux2KleinBase9BWeight),
}
264
265impl Preset {
266 fn try_configs_builder(self) -> Result<(ConfigBuilder, ModelConfigBuilder), ApiError> {
267 #[allow(unused_mut)]
268 let mut preset = match self {
269 Preset::StableDiffusion1_4 => stable_diffusion_1_4(),
270 Preset::StableDiffusion1_5 => stable_diffusion_1_5(),
271 Preset::StableDiffusion2_1 => stable_diffusion_2_1(),
272 Preset::StableDiffusion3Medium => stable_diffusion_3_medium(),
273 Preset::SDXLBase1_0 => sdxl_base_1_0(),
274 Preset::Flux1Dev(sd_type_t) => flux_1_dev(sd_type_t),
275 Preset::Flux1Schnell(sd_type_t) => flux_1_schnell(sd_type_t),
276 Preset::SDTurbo => sd_turbo(),
277 Preset::SDXLTurbo1_0 => sdxl_turbo_1_0(),
278 Preset::StableDiffusion3_5Large => stable_diffusion_3_5_large(),
279 Preset::StableDiffusion3_5Medium => stable_diffusion_3_5_medium(),
280 Preset::StableDiffusion3_5LargeTurbo => stable_diffusion_3_5_large_turbo(),
281 Preset::JuggernautXL11 => juggernaut_xl_11(),
282 Preset::Flux1Mini(sd_type_t) => flux_1_mini(sd_type_t),
283 Preset::Chroma(sd_type_t) => chroma(sd_type_t),
284 Preset::NitroSDRealism(sd_type_t) => nitro_sd_realism(sd_type_t),
285 Preset::NitroSDVibrant(sd_type_t) => nitro_sd_vibrant(sd_type_t),
286 Preset::DiffInstructStar(sd_type_t) => diff_instruct_star(sd_type_t),
287 Preset::ChromaRadiance(sd_type_t) => chroma_radiance(sd_type_t),
288 Preset::SSD1B(sd_type_t) => ssd_1b(sd_type_t),
289 Preset::Flux2Dev(sd_type_t) => flux_2_dev(sd_type_t),
290 Preset::ZImageTurbo(sd_type_t) => z_image_turbo(sd_type_t),
291 Preset::QwenImage(sd_type_t) => qwen_image(sd_type_t),
292 Preset::OvisImage(sd_type_t) => ovis_image(sd_type_t),
293 Preset::DreamShaperXL2_1Turbo => dream_shaper_xl_2_1_turbo(),
294 Preset::TwinFlowZImageTurboExp(sd_type_t) => twinflow_z_image_turbo(sd_type_t),
295 Preset::SDXS512DreamShaper => sdxs512_dream_shaper(),
296 Preset::Flux2Klein4B(sd_type_t) => flux_2_klein_4b(sd_type_t),
297 Preset::Flux2KleinBase4B(sd_type_t) => flux_2_klein_base_4b(sd_type_t),
298 Preset::Flux2Klein9B(sd_type_t) => flux_2_klein_9b(sd_type_t),
299 Preset::Flux2KleinBase9B(sd_type_t) => flux_2_klein_base_9b(sd_type_t),
300 };
301
302 #[cfg(feature = "metal")]
305 {
306 if let Ok((_, model_config)) = &mut preset {
307 model_config.clip_on_cpu(true);
308 };
309 }
310 preset
311 }
312}
313
314pub type ConfigsBuilder = (ConfigBuilder, ModelConfigBuilder);
316
317pub type Configs = (Config, ModelConfig);
319
320type ModifierFunction = dyn FnOnce(ConfigsBuilder) -> Result<ConfigsBuilder, ApiError>;
322
/// Inputs collected by [`PresetBuilder`] before being resolved into a `(Config, ModelConfig)`
/// pair.
///
/// The builder is generated by `derive_builder` with the owned pattern, and every setter
/// accepts `impl Into<_>`. The generated build function is private (`internal_build`). Callers
/// use `PresetBuilder::build`, which also resolves the preset and runs the modifiers.
#[derive(Builder)]
#[builder(
    name = "PresetBuilder",
    pattern = "owned",
    setter(into),
    build_fn(name = "internal_build", private, error = "ConfigBuilderError")
)]
pub struct PresetConfig {
    /// Text prompt. It is applied to the `ConfigBuilder` after all modifiers have run.
    prompt: String,
    /// Preset that supplies the base builder configuration.
    preset: Preset,
    /// Hooks applied in insertion order to the preset's builders.
    ///
    /// The generated setter is private; register hooks with `PresetBuilder::with_modifier`.
    /// Defaults to an empty list when none are registered.
    #[builder(private, default = "Vec::new()")]
    modifiers: Vec<Box<ModifierFunction>>,
}
337
338impl PresetBuilder {
339 pub fn with_modifier<F>(mut self, f: F) -> Self
341 where
342 F: FnOnce(ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> + 'static,
343 {
344 if self.modifiers.is_none() {
345 self.modifiers = Some(Vec::new());
346 }
347 self.modifiers.as_mut().unwrap().push(Box::new(f));
348 self
349 }
350
351 pub fn build(self) -> Result<Configs, ConfigBuilderError> {
352 let preset = self.internal_build()?;
353 let configs: ConfigsBuilder = preset
354 .try_into()
355 .map_err(|err: ApiError| ConfigBuilderError::ValidationError(err.to_string()))?;
356 let config = configs.0.build()?;
357 let config_model = configs.1.build()?;
358
359 Ok((config, config_model))
360 }
361}
362
363impl TryFrom<PresetConfig> for ConfigsBuilder {
364 type Error = ApiError;
365
366 fn try_from(value: PresetConfig) -> Result<Self, Self::Error> {
367 let mut configs_builder = value.preset.try_configs_builder()?;
368 for m in value.modifiers {
369 configs_builder = m(configs_builder)?;
370 }
371 configs_builder.0.prompt(value.prompt);
372 Ok(configs_builder)
373 }
374}
375
#[cfg(test)]
mod tests {
    //! Smoke tests that build each preset and generate one image with it.
    //!
    //! All tests are `#[ignore]`d; run them explicitly with `cargo test -- --ignored`.
    //! `run_with_token` embeds `../token.txt` at compile time via `include_str!`, so that
    //! file must exist for this module to compile.

    use super::{
        ChromaRadianceWeight, ChromaWeight, DiffInstructStarWeight, Flux1MiniWeight, Flux1Weight,
        Flux2Klein4BWeight, Flux2Klein9BWeight, Flux2KleinBase4BWeight, Flux2KleinBase9BWeight,
        Flux2Weight, NitroSDRealismWeight, NitroSDVibrantWeight, OvisImageWeight, Preset,
        PresetBuilder, QwenImageWeight, SSD1BWeight, TwinFlowZImageTurboExpWeight,
        ZImageTurboWeight,
    };
    use crate::{api::gen_img, util::set_hf_token};

    /// Prompt shared by every smoke test.
    static PROMPT: &str = "a lovely dinosaur made by crochet";

    /// Builds `preset` with [`PROMPT`] and generates an image, panicking on any failure.
    fn run(preset: Preset) {
        let builder = PresetBuilder::default().preset(preset).prompt(PROMPT);
        let (config, mut model_config) = builder.build().unwrap();
        gen_img(&config, &mut model_config).unwrap();
    }

    /// Installs the Hugging Face token from `../token.txt`, then runs [`run`].
    fn run_with_token(preset: Preset) {
        set_hf_token(include_str!("../token.txt"));
        run(preset);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_1_4() {
        run(Preset::StableDiffusion1_4);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_1_5() {
        run(Preset::StableDiffusion1_5);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_2_1() {
        run(Preset::StableDiffusion2_1);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_medium_fp16() {
        run_with_token(Preset::StableDiffusion3Medium);
    }

    #[ignore]
    #[test]
    fn test_sdxl_base_1_0() {
        run(Preset::SDXLBase1_0);
    }

    #[ignore]
    #[test]
    fn test_flux_1_dev() {
        run_with_token(Preset::Flux1Dev(Flux1Weight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_flux_1_schnell() {
        run_with_token(Preset::Flux1Schnell(Flux1Weight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_sd_turbo() {
        run(Preset::SDTurbo);
    }

    #[ignore]
    #[test]
    fn test_sdxl_turbo_1_0_fp16() {
        run(Preset::SDXLTurbo1_0);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_5_medium_fp16() {
        run_with_token(Preset::StableDiffusion3_5Medium);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_5_large_fp16() {
        run_with_token(Preset::StableDiffusion3_5Large);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_5_large_turbo_fp16() {
        run_with_token(Preset::StableDiffusion3_5LargeTurbo);
    }

    #[ignore]
    #[test]
    fn test_juggernaut_xl_11() {
        run_with_token(Preset::JuggernautXL11);
    }

    #[ignore]
    #[test]
    fn test_flux_1_mini() {
        run_with_token(Preset::Flux1Mini(Flux1MiniWeight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_chroma() {
        run_with_token(Preset::Chroma(ChromaWeight::Q4_0));
    }

    #[ignore]
    #[test]
    fn test_nitro_sd_realism() {
        run(Preset::NitroSDRealism(NitroSDRealismWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_nitro_sd_vibrant() {
        run(Preset::NitroSDVibrant(NitroSDVibrantWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_diff_instruct_star() {
        run(Preset::DiffInstructStar(DiffInstructStarWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_chroma_radiance() {
        run(Preset::ChromaRadiance(ChromaRadianceWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_ssd_1b() {
        run(Preset::SSD1B(SSD1BWeight::F8_E4M3));
    }

    #[ignore]
    #[test]
    fn test_flux_2_dev() {
        run_with_token(Preset::Flux2Dev(Flux2Weight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_z_image_turbo() {
        run_with_token(Preset::ZImageTurbo(ZImageTurboWeight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_qwen_image() {
        run(Preset::QwenImage(QwenImageWeight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_ovis_image() {
        run_with_token(Preset::OvisImage(OvisImageWeight::Q4_0));
    }

    #[ignore]
    #[test]
    fn test_dreamshaper_xl_2_1_turbo() {
        run(Preset::DreamShaperXL2_1Turbo);
    }

    #[ignore]
    #[test]
    fn test_twinflow_z_image_turbo_exp() {
        run_with_token(Preset::TwinFlowZImageTurboExp(
            TwinFlowZImageTurboExpWeight::Q3_K,
        ));
    }

    #[ignore]
    #[test]
    fn test_sdxs512_dream_shaper() {
        run(Preset::SDXS512DreamShaper);
    }

    #[ignore]
    #[test]
    fn test_flux_2_klein_4b() {
        run_with_token(Preset::Flux2Klein4B(Flux2Klein4BWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_flux_2_klein_base_4b() {
        run_with_token(Preset::Flux2KleinBase4B(Flux2KleinBase4BWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_flux_2_klein_9b() {
        run_with_token(Preset::Flux2Klein9B(Flux2Klein9BWeight::Q4_0));
    }

    #[ignore]
    #[test]
    fn test_flux_2_klein_base_9b() {
        run_with_token(Preset::Flux2KleinBase9B(Flux2KleinBase9BWeight::Q4_0));
    }
}