1use derive_builder::Builder;
2use hf_hub::api::sync::ApiError;
3use strum::{EnumDiscriminants, EnumString, VariantNames};
4use subenum::subenum;
5
6use crate::{
7 api::{Config, ConfigBuilder, ConfigBuilderError, ModelConfig, ModelConfigBuilder},
8 preset_builder::{
9 chroma, chroma_radiance, diff_instruct_star, dream_shaper_xl_2_1_turbo, flux_1_dev,
10 flux_1_mini, flux_1_schnell, flux_2_dev, juggernaut_xl_11, nitro_sd_realism,
11 nitro_sd_vibrant, ovis_image, qwen_image, sd_turbo, sdxl_base_1_0, sdxl_turbo_1_0, ssd_1b,
12 stable_diffusion_1_4, stable_diffusion_1_5, stable_diffusion_2_1,
13 stable_diffusion_3_5_large, stable_diffusion_3_5_large_turbo, stable_diffusion_3_5_medium,
14 stable_diffusion_3_medium, twinflow_z_image_turbo, z_image_turbo,
15 },
16};
17
/// Weight data types that the weight-parameterized [`Preset`] variants can be configured with.
///
/// NOTE(review): the variant list presumably mirrors the tensor types of the underlying
/// stable-diffusion backend — confirm before adding or removing variants.
///
/// The `subenum` attribute generates one narrowed enum per model family (listed in the
/// attribute below). A variant's own `#[subenum(...)]` tag names the narrowed enums it is
/// copied into. A `(default)` marker makes it that narrowed enum's `Default` variant, since each
/// narrowed enum derives `Default`. The resulting defaults are:
///
/// | Subenum                        | Default   |
/// |--------------------------------|-----------|
/// | `Flux1Weight`                  | `Q2_K`    |
/// | `Flux1MiniWeight`              | `Q8_0`    |
/// | `ChromaWeight`                 | `Q4_0`    |
/// | `NitroSDRealismWeight`         | `Q8_0`    |
/// | `NitroSDVibrantWeight`         | `Q8_0`    |
/// | `DiffInstructStarWeight`       | `Q8_0`    |
/// | `ChromaRadianceWeight`         | `Q8_0`    |
/// | `SSD1BWeight`                  | `F8_E4M3` |
/// | `Flux2Weight`                  | `Q2_K`    |
/// | `ZImageTurboWeight`            | `Q4_K`    |
/// | `QwenImageWeight`              | `Q2_K`    |
/// | `OvisImageWeight`              | `Q4_0`    |
/// | `TwinFlowZImageTurboExpWeight` | `Q4_0`    |
///
/// Variants without a `#[subenum]` tag belong to none of the narrowed enums.
///
/// `EnumString` provides a `FromStr` impl that matches variant names ASCII case-insensitively.
/// `VariantNames` exposes all variant names as strings.
#[non_exhaustive]
// Variant names such as `Q4_0` and `F8_E4M3` contain underscores, which would otherwise
// trigger the `non_camel_case_types` lint.
#[allow(non_camel_case_types)]
#[subenum(
    Flux1Weight(derive(Default)),
    Flux1MiniWeight(derive(Default)),
    ChromaWeight(derive(Default)),
    NitroSDRealismWeight(derive(Default)),
    NitroSDVibrantWeight(derive(Default)),
    DiffInstructStarWeight(derive(Default)),
    ChromaRadianceWeight(derive(Default)),
    SSD1BWeight(derive(Default)),
    Flux2Weight(derive(Default)),
    ZImageTurboWeight(derive(Default)),
    QwenImageWeight(derive(Default)),
    OvisImageWeight(derive(Default)),
    TwinFlowZImageTurboExpWeight(derive(Default))
)]
#[derive(Debug, Clone, Copy, EnumString, VariantNames)]
#[strum(ascii_case_insensitive)]
pub enum WeightType {
    #[subenum(Flux1MiniWeight)]
    F32,
    #[subenum(
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        SSD1BWeight
    )]
    F16,
    #[subenum(
        Flux1Weight,
        ChromaWeight(default),
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        OvisImageWeight(default),
        TwinFlowZImageTurboExpWeight(default)
    )]
    Q4_0,
    #[subenum(Flux2Weight, QwenImageWeight)]
    Q4_1,
    #[subenum(
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        TwinFlowZImageTurboExpWeight
    )]
    Q5_0,
    #[subenum(Flux2Weight, QwenImageWeight)]
    Q5_1,
    #[subenum(
        Flux1Weight,
        Flux1MiniWeight(default),
        ChromaWeight,
        NitroSDRealismWeight(default),
        NitroSDVibrantWeight(default),
        DiffInstructStarWeight(default),
        ChromaRadianceWeight(default),
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        OvisImageWeight,
        TwinFlowZImageTurboExpWeight
    )]
    Q8_0,
    // Not offered by any preset (no subenum tag).
    Q8_1,
    #[subenum(
        Flux1Weight(default),
        Flux1MiniWeight,
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        Flux2Weight(default),
        ZImageTurboWeight,
        QwenImageWeight(default)
    )]
    Q2_K,
    #[subenum(
        Flux1Weight,
        Flux1MiniWeight,
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        ZImageTurboWeight,
        Flux2Weight,
        QwenImageWeight,
        TwinFlowZImageTurboExpWeight
    )]
    Q3_K,
    #[subenum(Flux1Weight, ZImageTurboWeight(default), Flux2Weight, QwenImageWeight)]
    Q4_K,
    #[subenum(Flux1MiniWeight, Flux2Weight, QwenImageWeight)]
    Q5_K,
    #[subenum(
        Flux1MiniWeight,
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        TwinFlowZImageTurboExpWeight
    )]
    Q6_K,
    // The variants from `Q8_K` through `IQ1_M` carry no subenum tag, so no preset offers them.
    Q8_K,
    IQ2_XXS,
    IQ2_XS,
    IQ3_XXS,
    IQ1_S,
    IQ4_NL,
    IQ3_S,
    IQ2_S,
    IQ4_XS,
    I8,
    I16,
    I32,
    I64,
    F64,
    IQ1_M,
    #[subenum(
        Flux1MiniWeight,
        ChromaWeight,
        ChromaRadianceWeight,
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        OvisImageWeight,
        TwinFlowZImageTurboExpWeight
    )]
    BF16,
    // `TQ1_0`, `TQ2_0` and `MXFP4` carry no subenum tag either.
    TQ1_0,
    TQ2_0,
    MXFP4,
    #[subenum(SSD1BWeight(default), QwenImageWeight)]
    F8_E4M3,
}
161
/// A ready-made model setup that [`PresetBuilder`] expands into a [`ConfigsBuilder`].
///
/// Variants that carry a value take one of the narrowed [`WeightType`] subenums (for example
/// [`Flux1Weight`]). That subenum lists the weight formats offered for the model, and each
/// such subenum implements `Default`.
///
/// `EnumDiscriminants` also generates a field-less `PresetDiscriminants` enum. It implements
/// `FromStr` (ASCII case-insensitive, via `EnumString`) and `VariantNames`, so a preset can be
/// named as a string without choosing a weight.
///
/// NOTE(review): this module's tests call `set_hf_token` before running several presets:
/// `StableDiffusion3Medium`, the `StableDiffusion3_5*` variants, `JuggernautXL11`, `Flux1Dev`,
/// `Flux1Schnell`, `Flux1Mini`, `Chroma`, `Flux2Dev`, `ZImageTurbo`, `OvisImage` and
/// `TwinFlowZImageTurboExp`. This suggests their downloads need a Hugging Face token; confirm
/// per model.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, EnumDiscriminants)]
#[strum_discriminants(derive(EnumString, VariantNames), strum(ascii_case_insensitive))]
pub enum Preset {
    /// Stable Diffusion 1.4.
    StableDiffusion1_4,
    /// Stable Diffusion 1.5.
    StableDiffusion1_5,
    /// Stable Diffusion 2.1.
    StableDiffusion2_1,
    /// Stable Diffusion 3 Medium.
    StableDiffusion3Medium,
    /// Stable Diffusion 3.5 Medium.
    StableDiffusion3_5Medium,
    /// Stable Diffusion 3.5 Large.
    StableDiffusion3_5Large,
    /// Stable Diffusion 3.5 Large Turbo.
    StableDiffusion3_5LargeTurbo,
    /// SDXL Base 1.0.
    SDXLBase1_0,
    /// SD Turbo.
    SDTurbo,
    /// SDXL Turbo 1.0.
    SDXLTurbo1_0,
    /// Flux 1 Dev; weight chosen via [`Flux1Weight`] (default `Q2_K`).
    Flux1Dev(Flux1Weight),
    /// Flux 1 Schnell; weight chosen via [`Flux1Weight`] (default `Q2_K`).
    Flux1Schnell(Flux1Weight),
    /// Flux 1 Mini; weight chosen via [`Flux1MiniWeight`] (default `Q8_0`).
    Flux1Mini(Flux1MiniWeight),
    /// Juggernaut XL 11.
    JuggernautXL11,
    /// Chroma; weight chosen via [`ChromaWeight`] (default `Q4_0`).
    Chroma(ChromaWeight),
    /// NitroSD Realism; weight chosen via [`NitroSDRealismWeight`] (default `Q8_0`).
    NitroSDRealism(NitroSDRealismWeight),
    /// NitroSD Vibrant; weight chosen via [`NitroSDVibrantWeight`] (default `Q8_0`).
    NitroSDVibrant(NitroSDVibrantWeight),
    /// Diff-Instruct*; weight chosen via [`DiffInstructStarWeight`] (default `Q8_0`).
    DiffInstructStar(DiffInstructStarWeight),
    /// Chroma Radiance; weight chosen via [`ChromaRadianceWeight`] (default `Q8_0`).
    ChromaRadiance(ChromaRadianceWeight),
    /// SSD-1B; weight chosen via [`SSD1BWeight`] (default `F8_E4M3`).
    SSD1B(SSD1BWeight),
    /// Flux 2 Dev; weight chosen via [`Flux2Weight`] (default `Q2_K`).
    Flux2Dev(Flux2Weight),
    /// Z-Image Turbo; weight chosen via [`ZImageTurboWeight`] (default `Q4_K`).
    ZImageTurbo(ZImageTurboWeight),
    /// Qwen Image; weight chosen via [`QwenImageWeight`] (default `Q2_K`).
    QwenImage(QwenImageWeight),
    /// Ovis Image; weight chosen via [`OvisImageWeight`] (default `Q4_0`).
    OvisImage(OvisImageWeight),
    /// DreamShaper XL 2.1 Turbo.
    DreamShaperXL2_1Turbo,
    /// TwinFlow Z-Image Turbo (experimental); weight chosen via
    /// [`TwinFlowZImageTurboExpWeight`] (default `Q4_0`).
    TwinFlowZImageTurboExp(TwinFlowZImageTurboExpWeight),
}
233
234impl Preset {
235 fn try_configs_builder(self) -> Result<(ConfigBuilder, ModelConfigBuilder), ApiError> {
236 #[allow(unused_mut)]
237 let mut preset = match self {
238 Preset::StableDiffusion1_4 => stable_diffusion_1_4(),
239 Preset::StableDiffusion1_5 => stable_diffusion_1_5(),
240 Preset::StableDiffusion2_1 => stable_diffusion_2_1(),
241 Preset::StableDiffusion3Medium => stable_diffusion_3_medium(),
242 Preset::SDXLBase1_0 => sdxl_base_1_0(),
243 Preset::Flux1Dev(sd_type_t) => flux_1_dev(sd_type_t),
244 Preset::Flux1Schnell(sd_type_t) => flux_1_schnell(sd_type_t),
245 Preset::SDTurbo => sd_turbo(),
246 Preset::SDXLTurbo1_0 => sdxl_turbo_1_0(),
247 Preset::StableDiffusion3_5Large => stable_diffusion_3_5_large(),
248 Preset::StableDiffusion3_5Medium => stable_diffusion_3_5_medium(),
249 Preset::StableDiffusion3_5LargeTurbo => stable_diffusion_3_5_large_turbo(),
250 Preset::JuggernautXL11 => juggernaut_xl_11(),
251 Preset::Flux1Mini(sd_type_t) => flux_1_mini(sd_type_t),
252 Preset::Chroma(sd_type_t) => chroma(sd_type_t),
253 Preset::NitroSDRealism(sd_type_t) => nitro_sd_realism(sd_type_t),
254 Preset::NitroSDVibrant(sd_type_t) => nitro_sd_vibrant(sd_type_t),
255 Preset::DiffInstructStar(sd_type_t) => diff_instruct_star(sd_type_t),
256 Preset::ChromaRadiance(sd_type_t) => chroma_radiance(sd_type_t),
257 Preset::SSD1B(sd_type_t) => ssd_1b(sd_type_t),
258 Preset::Flux2Dev(sd_type_t) => flux_2_dev(sd_type_t),
259 Preset::ZImageTurbo(sd_type_t) => z_image_turbo(sd_type_t),
260 Preset::QwenImage(sd_type_t) => qwen_image(sd_type_t),
261 Preset::OvisImage(sd_type_t) => ovis_image(sd_type_t),
262 Preset::DreamShaperXL2_1Turbo => dream_shaper_xl_2_1_turbo(),
263 Preset::TwinFlowZImageTurboExp(sd_type_t) => twinflow_z_image_turbo(sd_type_t),
264 };
265
266 #[cfg(feature = "metal")]
269 {
270 if let Ok((_, model_config)) = &mut preset {
271 model_config.clip_on_cpu(true);
272 };
273 }
274 preset
275 }
276}
277
/// The builder pair a [`Preset`] resolves to: a [`ConfigBuilder`] together with a
/// [`ModelConfigBuilder`].
pub type ConfigsBuilder = (ConfigBuilder, ModelConfigBuilder);

/// The finished configurations returned by [`PresetBuilder::build`].
pub type Configs = (Config, ModelConfig);

/// A type-erased, one-shot hook that transforms a [`ConfigsBuilder`] and may fail with an
/// [`ApiError`]. Hooks are registered through [`PresetBuilder::with_modifier`].
type ModifierFunction = dyn FnOnce(ConfigsBuilder) -> Result<ConfigsBuilder, ApiError>;
286
/// Input collected by [`PresetBuilder`]: a prompt, a [`Preset`], and an ordered list of
/// modifiers applied to the preset's builders.
///
/// The `Builder` derive generates `PresetBuilder` with owned (by-value) setters that accept
/// anything convertible `Into` the field type. Its generated build function is renamed to
/// `internal_build` and kept private, and it reports failures as `ConfigBuilderError`.
/// Callers go through [`PresetBuilder::build`] instead.
#[derive(Builder)]
#[builder(
    name = "PresetBuilder",
    pattern = "owned",
    setter(into),
    build_fn(name = "internal_build", private, error = "ConfigBuilderError")
)]
pub struct PresetConfig {
    /// Prompt text; it is applied to the `ConfigBuilder` after all modifiers have run.
    prompt: String,
    /// Preset that supplies the initial builders.
    preset: Preset,
    /// Modifiers in registration order. Its generated setter is private, so it is filled
    /// through [`PresetBuilder::with_modifier`]. It defaults to empty.
    #[builder(private, default = "Vec::new()")]
    modifiers: Vec<Box<ModifierFunction>>,
}
301
302impl PresetBuilder {
303 pub fn with_modifier<F>(mut self, f: F) -> Self
305 where
306 F: FnOnce(ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> + 'static,
307 {
308 if self.modifiers.is_none() {
309 self.modifiers = Some(Vec::new());
310 }
311 self.modifiers.as_mut().unwrap().push(Box::new(f));
312 self
313 }
314
315 pub fn build(self) -> Result<Configs, ConfigBuilderError> {
316 let preset = self.internal_build()?;
317 let configs: ConfigsBuilder = preset
318 .try_into()
319 .map_err(|err: ApiError| ConfigBuilderError::ValidationError(err.to_string()))?;
320 let config = configs.0.build()?;
321 let config_model = configs.1.build()?;
322
323 Ok((config, config_model))
324 }
325}
326
327impl TryFrom<PresetConfig> for ConfigsBuilder {
328 type Error = ApiError;
329
330 fn try_from(value: PresetConfig) -> Result<Self, Self::Error> {
331 let mut configs_builder = value.preset.try_configs_builder()?;
332 for m in value.modifiers {
333 configs_builder = m(configs_builder)?;
334 }
335 configs_builder.0.prompt(value.prompt);
336 Ok(configs_builder)
337 }
338}
339
#[cfg(test)]
mod tests {
    use crate::{
        api::gen_img,
        preset::{
            ChromaRadianceWeight, ChromaWeight, DiffInstructStarWeight, Flux1MiniWeight,
            Flux1Weight, Flux2Weight, NitroSDRealismWeight, NitroSDVibrantWeight, OvisImageWeight,
            QwenImageWeight, SSD1BWeight, TwinFlowZImageTurboExpWeight, ZImageTurboWeight,
        },
        util::set_hf_token,
    };

    use super::{Preset, PresetBuilder};

    /// Prompt shared by every preset smoke test.
    static PROMPT: &str = "a lovely dinosaur made by crochet";

    /// Hugging Face token embedded at compile time. Because `include_str!` is evaluated when
    /// the crate is compiled, `token.txt` must exist for this test module to build, even
    /// though the tests are `#[ignore]`d.
    const HF_TOKEN: &str = include_str!("../token.txt");

    /// Builds `preset` with [`PROMPT`] and generates an image, panicking on any failure.
    fn run(preset: Preset) {
        let builder = PresetBuilder::default().preset(preset).prompt(PROMPT);
        let (config, mut model_config) = builder.build().unwrap();
        gen_img(&config, &mut model_config).unwrap();
    }

    /// Installs the Hugging Face token, then behaves exactly like `run`.
    fn run_with_token(preset: Preset) {
        set_hf_token(HF_TOKEN);
        run(preset);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_1_4() {
        run(Preset::StableDiffusion1_4);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_1_5() {
        run(Preset::StableDiffusion1_5);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_2_1() {
        run(Preset::StableDiffusion2_1);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_medium_fp16() {
        run_with_token(Preset::StableDiffusion3Medium);
    }

    #[ignore]
    #[test]
    fn test_sdxl_base_1_0() {
        run(Preset::SDXLBase1_0);
    }

    #[ignore]
    #[test]
    fn test_flux_1_dev() {
        run_with_token(Preset::Flux1Dev(Flux1Weight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_flux_1_schnell() {
        run_with_token(Preset::Flux1Schnell(Flux1Weight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_sd_turbo() {
        run(Preset::SDTurbo);
    }

    #[ignore]
    #[test]
    fn test_sdxl_turbo_1_0_fp16() {
        run(Preset::SDXLTurbo1_0);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_5_medium_fp16() {
        run_with_token(Preset::StableDiffusion3_5Medium);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_5_large_fp16() {
        run_with_token(Preset::StableDiffusion3_5Large);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_5_large_turbo_fp16() {
        run_with_token(Preset::StableDiffusion3_5LargeTurbo);
    }

    #[ignore]
    #[test]
    fn test_juggernaut_xl_11() {
        run_with_token(Preset::JuggernautXL11);
    }

    #[ignore]
    #[test]
    fn test_flux_1_mini() {
        run_with_token(Preset::Flux1Mini(Flux1MiniWeight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_chroma() {
        run_with_token(Preset::Chroma(ChromaWeight::Q4_0));
    }

    #[ignore]
    #[test]
    fn test_nitro_sd_realism() {
        run(Preset::NitroSDRealism(NitroSDRealismWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_nitro_sd_vibrant() {
        run(Preset::NitroSDVibrant(NitroSDVibrantWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_diff_instruct_star() {
        run(Preset::DiffInstructStar(DiffInstructStarWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_chroma_radiance() {
        run(Preset::ChromaRadiance(ChromaRadianceWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_ssd_1b() {
        run(Preset::SSD1B(SSD1BWeight::F8_E4M3));
    }

    #[ignore]
    #[test]
    fn test_flux_2_dev() {
        run_with_token(Preset::Flux2Dev(Flux2Weight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_z_image_turbo() {
        run_with_token(Preset::ZImageTurbo(ZImageTurboWeight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_qwen_image() {
        run(Preset::QwenImage(QwenImageWeight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_ovis_image() {
        run_with_token(Preset::OvisImage(OvisImageWeight::Q4_0));
    }

    #[ignore]
    #[test]
    fn test_dreamshaper_xl_2_1_turbo() {
        run(Preset::DreamShaperXL2_1Turbo);
    }

    #[ignore]
    #[test]
    fn test_twinflow_z_image_turbo_exp() {
        run_with_token(Preset::TwinFlowZImageTurboExp(
            TwinFlowZImageTurboExpWeight::Q3_K,
        ));
    }
}