1use derive_builder::Builder;
2use hf_hub::api::sync::ApiError;
3use strum::{EnumDiscriminants, EnumString, VariantNames};
4use subenum::subenum;
5
6use crate::{
7 api::{Config, ConfigBuilder, ConfigBuilderError, ModelConfig, ModelConfigBuilder},
8 preset_builder::{
9 chroma, chroma_radiance, diff_instruct_star, flux_1_dev, flux_1_mini, flux_1_schnell,
10 flux_2_dev, juggernaut_xl_11, nitro_sd_realism, nitro_sd_vibrant, ovis_image, qwen_image,
11 sd_turbo, sdxl_base_1_0, sdxl_turbo_1_0_fp16, ssd_1b, stable_diffusion_1_4,
12 stable_diffusion_1_5, stable_diffusion_2_1, stable_diffusion_3_5_large_fp16,
13 stable_diffusion_3_5_large_turbo_fp16, stable_diffusion_3_5_medium_fp16,
14 stable_diffusion_3_medium_fp16, z_image_turbo,
15 },
16};
17
/// Weight (tensor storage / quantization) format of a model file.
///
/// The variant names appear to mirror ggml tensor type names.
/// NOTE(review): confirm against the inference backend this crate binds to.
///
/// `#[subenum]` generates one restricted enum per model family. Each one contains only the
/// variants tagged with that family below, and these restricted enums are what the
/// [`Preset`] variants carry. Each generated enum derives `Default` from the variant tagged
/// `(default)` for it:
///
/// | Sub-enum                 | Default   |
/// |--------------------------|-----------|
/// | `Flux1Weight`            | `Q2_K`    |
/// | `Flux1MiniWeight`        | `Q8_0`    |
/// | `ChromaWeight`           | `Q4_0`    |
/// | `NitroSDRealismWeight`   | `Q8_0`    |
/// | `NitroSDVibrantWeight`   | `Q8_0`    |
/// | `DiffInstructStarWeight` | `Q8_0`    |
/// | `ChromaRadianceWeight`   | `Q8_0`    |
/// | `SSD1BWeight`            | `F8_E4M3` |
/// | `Flux2Weight`            | `Q2_K`    |
/// | `ZImageTurboWeight`      | `Q4_K`    |
/// | `QwenImageWeight`        | `Q2_K`    |
/// | `OvisImageWeight`        | `Q4_0`    |
///
/// Variants without any `#[subenum]` tag cannot be selected through a preset.
///
/// Strings are parsed via `EnumString`, ignoring ASCII case. Declaration order is
/// significant because it is the order reported by `VariantNames::VARIANTS`. Append new
/// variants rather than reordering existing ones.
#[non_exhaustive]
// Variant names such as `Q4_0` contain underscores, which this lint would otherwise flag.
#[allow(non_camel_case_types)]
#[subenum(
    Flux1Weight(derive(Default)),
    Flux1MiniWeight(derive(Default)),
    ChromaWeight(derive(Default)),
    NitroSDRealismWeight(derive(Default)),
    NitroSDVibrantWeight(derive(Default)),
    DiffInstructStarWeight(derive(Default)),
    ChromaRadianceWeight(derive(Default)),
    SSD1BWeight(derive(Default)),
    Flux2Weight(derive(Default)),
    ZImageTurboWeight(derive(Default)),
    QwenImageWeight(derive(Default)),
    OvisImageWeight(derive(Default))
)]
#[derive(Debug, Clone, Copy, EnumString, VariantNames)]
#[strum(ascii_case_insensitive)]
pub enum WeightType {
    #[subenum(Flux1MiniWeight)]
    F32,
    #[subenum(
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        SSD1BWeight
    )]
    F16,
    // Default for `ChromaWeight` and `OvisImageWeight`.
    #[subenum(
        Flux1Weight,
        ChromaWeight(default),
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        OvisImageWeight(default)
    )]
    Q4_0,
    #[subenum(Flux2Weight, QwenImageWeight)]
    Q4_1,
    #[subenum(
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight
    )]
    Q5_0,
    #[subenum(Flux2Weight, QwenImageWeight)]
    Q5_1,
    // Default for `Flux1MiniWeight`, the NitroSD weights, `DiffInstructStarWeight` and
    // `ChromaRadianceWeight`.
    #[subenum(
        Flux1Weight,
        Flux1MiniWeight(default),
        ChromaWeight,
        NitroSDRealismWeight(default),
        NitroSDVibrantWeight(default),
        DiffInstructStarWeight(default),
        ChromaRadianceWeight(default),
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        OvisImageWeight
    )]
    Q8_0,
    // Not offered by any preset.
    Q8_1,
    // Default for `Flux1Weight`, `Flux2Weight` and `QwenImageWeight`.
    #[subenum(
        Flux1Weight(default),
        Flux1MiniWeight,
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        Flux2Weight(default),
        ZImageTurboWeight,
        QwenImageWeight(default)
    )]
    Q2_K,
    #[subenum(
        Flux1Weight,
        Flux1MiniWeight,
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        ZImageTurboWeight,
        Flux2Weight,
        QwenImageWeight
    )]
    Q3_K,
    // Default for `ZImageTurboWeight`.
    #[subenum(Flux1Weight, ZImageTurboWeight(default), Flux2Weight, QwenImageWeight)]
    Q4_K,
    #[subenum(Flux1MiniWeight, Flux2Weight, QwenImageWeight)]
    Q5_K,
    #[subenum(
        Flux1MiniWeight,
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight
    )]
    Q6_K,
    // From here up to (but excluding) `BF16`, no variant is offered by any preset.
    Q8_K,
    IQ2_XXS,
    IQ2_XS,
    IQ3_XXS,
    IQ1_S,
    IQ4_NL,
    IQ3_S,
    IQ2_S,
    IQ4_XS,
    I8,
    I16,
    I32,
    I64,
    F64,
    IQ1_M,
    #[subenum(
        Flux1MiniWeight,
        ChromaWeight,
        ChromaRadianceWeight,
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        OvisImageWeight
    )]
    BF16,
    // Not offered by any preset.
    TQ1_0,
    TQ2_0,
    MXFP4,
    // Default for `SSD1BWeight`.
    #[subenum(SSD1BWeight(default), QwenImageWeight)]
    F8_E4M3,
}
154
/// A predefined model configuration that [`PresetBuilder`] expands into ready-to-use configs.
///
/// Each variant is resolved by a matching function in `crate::preset_builder` (named in the
/// variant docs below).
///
/// Variants that carry a weight enum forward that value to their function to select the
/// weight type. Every weight enum implements `Default`; its default is listed per variant.
///
/// `EnumDiscriminants` also generates a fieldless `PresetDiscriminants` enum. That enum
/// parses from strings (ASCII case-insensitive) and exposes its variant names through
/// `VariantNames`, in declaration order.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, EnumDiscriminants)]
#[strum_discriminants(derive(EnumString, VariantNames), strum(ascii_case_insensitive))]
pub enum Preset {
    /// Stable Diffusion 1.4 (`stable_diffusion_1_4`).
    StableDiffusion1_4,
    /// Stable Diffusion 1.5 (`stable_diffusion_1_5`).
    StableDiffusion1_5,
    /// Stable Diffusion 2.1 (`stable_diffusion_2_1`).
    StableDiffusion2_1,
    /// Stable Diffusion 3 Medium, fp16 (`stable_diffusion_3_medium_fp16`).
    StableDiffusion3MediumFp16,
    /// Stable Diffusion 3.5 Medium, fp16 (`stable_diffusion_3_5_medium_fp16`).
    StableDiffusion3_5MediumFp16,
    /// Stable Diffusion 3.5 Large, fp16 (`stable_diffusion_3_5_large_fp16`).
    StableDiffusion3_5LargeFp16,
    /// Stable Diffusion 3.5 Large Turbo, fp16 (`stable_diffusion_3_5_large_turbo_fp16`).
    StableDiffusion3_5LargeTurboFp16,
    /// SDXL Base 1.0 (`sdxl_base_1_0`).
    SDXLBase1_0,
    /// SD Turbo (`sd_turbo`).
    SDTurbo,
    /// SDXL Turbo 1.0, fp16 (`sdxl_turbo_1_0_fp16`).
    SDXLTurbo1_0Fp16,
    /// FLUX.1 dev (`flux_1_dev`); `Flux1Weight::default()` is `Q2_K`.
    Flux1Dev(Flux1Weight),
    /// FLUX.1 schnell (`flux_1_schnell`); `Flux1Weight::default()` is `Q2_K`.
    Flux1Schnell(Flux1Weight),
    /// FLUX.1 mini (`flux_1_mini`); `Flux1MiniWeight::default()` is `Q8_0`.
    Flux1Mini(Flux1MiniWeight),
    /// Juggernaut XL 11 (`juggernaut_xl_11`).
    JuggernautXL11,
    /// Chroma (`chroma`); `ChromaWeight::default()` is `Q4_0`.
    Chroma(ChromaWeight),
    /// NitroSD Realism (`nitro_sd_realism`); `NitroSDRealismWeight::default()` is `Q8_0`.
    NitroSDRealism(NitroSDRealismWeight),
    /// NitroSD Vibrant (`nitro_sd_vibrant`); `NitroSDVibrantWeight::default()` is `Q8_0`.
    NitroSDVibrant(NitroSDVibrantWeight),
    /// Diff-Instruct* (`diff_instruct_star`); `DiffInstructStarWeight::default()` is `Q8_0`.
    DiffInstructStar(DiffInstructStarWeight),
    /// Chroma Radiance (`chroma_radiance`); `ChromaRadianceWeight::default()` is `Q8_0`.
    ChromaRadiance(ChromaRadianceWeight),
    /// SSD-1B (`ssd_1b`); `SSD1BWeight::default()` is `F8_E4M3`.
    SSD1B(SSD1BWeight),
    /// FLUX.2 dev (`flux_2_dev`); `Flux2Weight::default()` is `Q2_K`.
    Flux2Dev(Flux2Weight),
    /// Z-Image Turbo (`z_image_turbo`); `ZImageTurboWeight::default()` is `Q4_K`.
    ZImageTurbo(ZImageTurboWeight),
    /// Qwen-Image (`qwen_image`); `QwenImageWeight::default()` is `Q2_K`.
    QwenImage(QwenImageWeight),
    /// Ovis-Image (`ovis_image`); `OvisImageWeight::default()` is `Q4_0`.
    OvisImage(OvisImageWeight),
}
220
221impl Preset {
222 fn try_configs_builder(self) -> Result<(ConfigBuilder, ModelConfigBuilder), ApiError> {
223 #[allow(unused_mut)]
224 let mut preset = match self {
225 Preset::StableDiffusion1_4 => stable_diffusion_1_4(),
226 Preset::StableDiffusion1_5 => stable_diffusion_1_5(),
227 Preset::StableDiffusion2_1 => stable_diffusion_2_1(),
228 Preset::StableDiffusion3MediumFp16 => stable_diffusion_3_medium_fp16(),
229 Preset::SDXLBase1_0 => sdxl_base_1_0(),
230 Preset::Flux1Dev(sd_type_t) => flux_1_dev(sd_type_t),
231 Preset::Flux1Schnell(sd_type_t) => flux_1_schnell(sd_type_t),
232 Preset::SDTurbo => sd_turbo(),
233 Preset::SDXLTurbo1_0Fp16 => sdxl_turbo_1_0_fp16(),
234 Preset::StableDiffusion3_5LargeFp16 => stable_diffusion_3_5_large_fp16(),
235 Preset::StableDiffusion3_5MediumFp16 => stable_diffusion_3_5_medium_fp16(),
236 Preset::StableDiffusion3_5LargeTurboFp16 => stable_diffusion_3_5_large_turbo_fp16(),
237 Preset::JuggernautXL11 => juggernaut_xl_11(),
238 Preset::Flux1Mini(sd_type_t) => flux_1_mini(sd_type_t),
239 Preset::Chroma(sd_type_t) => chroma(sd_type_t),
240 Preset::NitroSDRealism(sd_type_t) => nitro_sd_realism(sd_type_t),
241 Preset::NitroSDVibrant(sd_type_t) => nitro_sd_vibrant(sd_type_t),
242 Preset::DiffInstructStar(sd_type_t) => diff_instruct_star(sd_type_t),
243 Preset::ChromaRadiance(sd_type_t) => chroma_radiance(sd_type_t),
244 Preset::SSD1B(sd_type_t) => ssd_1b(sd_type_t),
245 Preset::Flux2Dev(sd_type_t) => flux_2_dev(sd_type_t),
246 Preset::ZImageTurbo(sd_type_t) => z_image_turbo(sd_type_t),
247 Preset::QwenImage(sd_type_t) => qwen_image(sd_type_t),
248 Preset::OvisImage(sd_type_t) => ovis_image(sd_type_t),
249 };
250
251 #[cfg(feature = "metal")]
254 {
255 if let Ok((_, model_config)) = &mut preset {
256 model_config.clip_on_cpu(true);
257 };
258 }
259 preset
260 }
261}
262
/// The pair of builders a [`Preset`] expands into: a [`ConfigBuilder`] and a
/// [`ModelConfigBuilder`], in that order.
pub type ConfigsBuilder = (ConfigBuilder, ModelConfigBuilder);

/// The built counterpart of [`ConfigsBuilder`]: a [`Config`] and a [`ModelConfig`].
pub type Configs = (Config, ModelConfig);

/// A one-shot transformation of the builders pair, stored boxed in `PresetConfig`.
/// Returning an error aborts the conversion with that [`ApiError`].
type Modifier = dyn FnOnce(ConfigsBuilder) -> Result<ConfigsBuilder, ApiError>;
271
/// Inputs needed to turn a [`Preset`] into [`Configs`].
///
/// Build it with the generated [`PresetBuilder`]. The derived build function is private
/// (`internal_build`); the public entry point is the hand-written `PresetBuilder::build`.
/// Missing required fields are reported as a [`ConfigBuilderError`].
#[derive(Builder)]
#[builder(
    name = "PresetBuilder",
    pattern = "owned",
    setter(into),
    build_fn(name = "internal_build", private, error = "ConfigBuilderError")
)]
pub struct PresetConfig {
    /// Prompt written into the [`ConfigBuilder`] once all modifiers have run.
    prompt: String,
    /// Preset supplying the initial builders.
    preset: Preset,
    // The generated setter is private; entries are added one at a time through
    // `PresetBuilder::with_modifier` and are applied in insertion order. Empty by default.
    #[builder(private, default = "Vec::new()")]
    modifiers: Vec<Box<Modifier>>,
}
286
287impl PresetBuilder {
288 pub fn with_modifier<F>(mut self, f: F) -> Self
290 where
291 F: FnOnce(ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> + 'static,
292 {
293 if self.modifiers.is_none() {
294 self.modifiers = Some(Vec::new());
295 }
296 self.modifiers.as_mut().unwrap().push(Box::new(f));
297 self
298 }
299
300 pub fn build(self) -> Result<Configs, ConfigBuilderError> {
301 let preset = self.internal_build()?;
302 let configs: ConfigsBuilder = preset
303 .try_into()
304 .map_err(|err: ApiError| ConfigBuilderError::ValidationError(err.to_string()))?;
305 let config = configs.0.build()?;
306 let config_model = configs.1.build()?;
307
308 Ok((config, config_model))
309 }
310}
311
312impl TryFrom<PresetConfig> for ConfigsBuilder {
313 type Error = ApiError;
314
315 fn try_from(value: PresetConfig) -> Result<Self, Self::Error> {
316 let mut configs_builder = value.preset.try_configs_builder()?;
317 for m in value.modifiers {
318 configs_builder = m(configs_builder)?;
319 }
320 configs_builder.0.prompt(value.prompt);
321 Ok(configs_builder)
322 }
323}
324
#[cfg(test)]
mod tests {
    use crate::{
        api::gen_img,
        preset::{
            ChromaRadianceWeight, ChromaWeight, DiffInstructStarWeight, Flux1MiniWeight,
            Flux1Weight, Flux2Weight, NitroSDRealismWeight, NitroSDVibrantWeight, OvisImageWeight,
            QwenImageWeight, SSD1BWeight, ZImageTurboWeight,
        },
        util::set_hf_token,
    };

    use super::{Preset, PresetBuilder};

    /// Prompt shared by every end-to-end preset test.
    const PROMPT: &str = "a lovely dinosaur made by crochet";

    /// Builds the configs for `preset` with [`PROMPT`] and generates one image.
    /// Panics on any failure.
    fn run(preset: Preset) {
        let builder = PresetBuilder::default().preset(preset).prompt(PROMPT);
        let (config, mut model_config) = builder.build().unwrap();
        gen_img(&config, &mut model_config).unwrap();
    }

    /// Like [`run`], but first installs the Hugging Face token, for presets that presumably
    /// download from gated repositories.
    ///
    /// NOTE: `include_str!` reads `../token.txt` at compile time, so this test module only
    /// compiles when that file exists.
    fn run_with_token(preset: Preset) {
        set_hf_token(include_str!("../token.txt"));
        run(preset);
    }

    #[test]
    #[ignore]
    fn test_stable_diffusion_1_4() {
        run(Preset::StableDiffusion1_4);
    }

    #[test]
    #[ignore]
    fn test_stable_diffusion_1_5() {
        run(Preset::StableDiffusion1_5);
    }

    #[test]
    #[ignore]
    fn test_stable_diffusion_2_1() {
        run(Preset::StableDiffusion2_1);
    }

    #[test]
    #[ignore]
    fn test_stable_diffusion_3_medium_fp16() {
        run_with_token(Preset::StableDiffusion3MediumFp16);
    }

    #[test]
    #[ignore]
    fn test_sdxl_base_1_0() {
        run(Preset::SDXLBase1_0);
    }

    #[test]
    #[ignore]
    fn test_flux_1_dev() {
        run_with_token(Preset::Flux1Dev(Flux1Weight::Q2_K));
    }

    #[test]
    #[ignore]
    fn test_flux_1_schnell() {
        run_with_token(Preset::Flux1Schnell(Flux1Weight::Q2_K));
    }

    #[test]
    #[ignore]
    fn test_sd_turbo() {
        run(Preset::SDTurbo);
    }

    #[test]
    #[ignore]
    fn test_sdxl_turbo_1_0_fp16() {
        run(Preset::SDXLTurbo1_0Fp16);
    }

    #[test]
    #[ignore]
    fn test_stable_diffusion_3_5_medium_fp16() {
        run_with_token(Preset::StableDiffusion3_5MediumFp16);
    }

    #[test]
    #[ignore]
    fn test_stable_diffusion_3_5_large_fp16() {
        run_with_token(Preset::StableDiffusion3_5LargeFp16);
    }

    #[test]
    #[ignore]
    fn test_stable_diffusion_3_5_large_turbo_fp16() {
        run_with_token(Preset::StableDiffusion3_5LargeTurboFp16);
    }

    #[test]
    #[ignore]
    fn test_juggernaut_xl_11() {
        run_with_token(Preset::JuggernautXL11);
    }

    #[test]
    #[ignore]
    fn test_flux_1_mini() {
        run_with_token(Preset::Flux1Mini(Flux1MiniWeight::Q2_K));
    }

    #[test]
    #[ignore]
    fn test_chroma() {
        run_with_token(Preset::Chroma(ChromaWeight::Q4_0));
    }

    #[test]
    #[ignore]
    fn test_nitro_sd_realism() {
        run(Preset::NitroSDRealism(NitroSDRealismWeight::Q8_0));
    }

    #[test]
    #[ignore]
    fn test_nitro_sd_vibrant() {
        run(Preset::NitroSDVibrant(NitroSDVibrantWeight::Q8_0));
    }

    #[test]
    #[ignore]
    fn test_diff_instruct_star() {
        run(Preset::DiffInstructStar(DiffInstructStarWeight::Q8_0));
    }

    #[test]
    #[ignore]
    fn test_chroma_radiance() {
        run(Preset::ChromaRadiance(ChromaRadianceWeight::Q8_0));
    }

    #[test]
    #[ignore]
    fn test_ssd_1b() {
        run(Preset::SSD1B(SSD1BWeight::F8_E4M3));
    }

    #[test]
    #[ignore]
    fn test_flux_2_dev() {
        run_with_token(Preset::Flux2Dev(Flux2Weight::Q2_K));
    }

    #[test]
    #[ignore]
    fn test_z_image_turbo() {
        run_with_token(Preset::ZImageTurbo(ZImageTurboWeight::Q2_K));
    }

    #[test]
    #[ignore]
    fn test_qwen_image() {
        run(Preset::QwenImage(QwenImageWeight::Q2_K));
    }

    #[test]
    #[ignore]
    fn test_ovis_image() {
        run_with_token(Preset::OvisImage(OvisImageWeight::Q4_0));
    }
}