diffusion_rs/
preset.rs

1use derive_builder::Builder;
2use hf_hub::api::sync::ApiError;
3use subenum::subenum;
4
5use crate::{
6    api::{Config, ConfigBuilder, ConfigBuilderError, ModelConfig, ModelConfigBuilder},
7    preset_builder::{
8        chroma, chroma_radiance, diff_instruct_star, flux_1_dev, flux_1_mini, flux_1_schnell,
9        flux_2_dev, juggernaut_xl_11, nitro_sd_realism, nitro_sd_vibrant, ovis_image, qwen_image,
10        sd_turbo, sdxl_base_1_0, sdxl_turbo_1_0_fp16, ssd_1b, stable_diffusion_1_4,
11        stable_diffusion_1_5, stable_diffusion_2_1, stable_diffusion_3_5_large_fp16,
12        stable_diffusion_3_5_large_turbo_fp16, stable_diffusion_3_5_medium_fp16,
13        stable_diffusion_3_medium_fp16, z_image_turbo,
14    },
15};
16
17#[non_exhaustive]
18#[allow(non_camel_case_types)]
19#[subenum(
20    Flux1Weight,
21    Flux1MiniWeight,
22    ChromaWeight,
23    NitroSDRealismWeight,
24    NitroSDVibrantWeight,
25    DiffInstructStarWeight,
26    ChromaRadianceWeight,
27    SSD1BWeight,
28    Flux2Weight,
29    ZImageTurboWeight,
30    QwenImageWeight,
31    OvisImageWeight
32)]
33#[derive(Debug, Clone, Copy)]
34/// Model weight types
35pub enum WeightType {
36    #[subenum(Flux1MiniWeight)]
37    F32,
38    #[subenum(
39        NitroSDRealismWeight,
40        NitroSDVibrantWeight,
41        DiffInstructStarWeight,
42        SSD1BWeight
43    )]
44    F16,
45    #[subenum(
46        Flux1Weight,
47        ChromaWeight,
48        NitroSDRealismWeight,
49        NitroSDVibrantWeight,
50        DiffInstructStarWeight,
51        Flux2Weight,
52        ZImageTurboWeight,
53        QwenImageWeight,
54        OvisImageWeight
55    )]
56    Q4_0,
57    #[subenum(Flux2Weight, QwenImageWeight)]
58    Q4_1,
59    #[subenum(
60        NitroSDRealismWeight,
61        NitroSDVibrantWeight,
62        DiffInstructStarWeight,
63        Flux2Weight,
64        ZImageTurboWeight,
65        QwenImageWeight
66    )]
67    Q5_0,
68    #[subenum(Flux2Weight, QwenImageWeight)]
69    Q5_1,
70    #[subenum(
71        Flux1Weight,
72        Flux1MiniWeight,
73        ChromaWeight,
74        NitroSDRealismWeight,
75        NitroSDVibrantWeight,
76        DiffInstructStarWeight,
77        ChromaRadianceWeight,
78        Flux2Weight,
79        ZImageTurboWeight,
80        QwenImageWeight,
81        OvisImageWeight
82    )]
83    Q8_0,
84    Q8_1,
85    #[subenum(
86        Flux1Weight,
87        Flux1MiniWeight,
88        NitroSDRealismWeight,
89        NitroSDVibrantWeight,
90        DiffInstructStarWeight,
91        Flux2Weight,
92        ZImageTurboWeight,
93        QwenImageWeight
94    )]
95    Q2_K,
96    #[subenum(
97        Flux1Weight,
98        Flux1MiniWeight,
99        NitroSDRealismWeight,
100        NitroSDVibrantWeight,
101        DiffInstructStarWeight,
102        ZImageTurboWeight,
103        Flux2Weight,
104        QwenImageWeight
105    )]
106    Q3_K,
107    #[subenum(Flux1Weight, ZImageTurboWeight, Flux2Weight, QwenImageWeight)]
108    Q4_K,
109    #[subenum(Flux1MiniWeight, Flux2Weight, QwenImageWeight)]
110    Q5_K,
111    #[subenum(
112        Flux1MiniWeight,
113        NitroSDRealismWeight,
114        NitroSDVibrantWeight,
115        DiffInstructStarWeight,
116        Flux2Weight,
117        ZImageTurboWeight,
118        QwenImageWeight
119    )]
120    Q6_K,
121    Q8_K,
122    IQ2_XXS,
123    IQ2_XS,
124    IQ3_XXS,
125    IQ1_S,
126    IQ4_NL,
127    IQ3_S,
128    IQ2_S,
129    IQ4_XS,
130    I8,
131    I16,
132    I32,
133    I64,
134    F64,
135    IQ1_M,
136    #[subenum(
137        Flux1MiniWeight,
138        ChromaWeight,
139        ChromaRadianceWeight,
140        Flux2Weight,
141        ZImageTurboWeight,
142        QwenImageWeight,
143        OvisImageWeight
144    )]
145    BF16,
146    TQ1_0,
147    TQ2_0,
148    MXFP4,
149    #[subenum(SSD1BWeight, QwenImageWeight)]
150    F8_E4M3,
151}
152
153#[non_exhaustive]
154#[derive(Debug, Clone, Copy)]
155/// Models ready to use
156pub enum Preset {
157    StableDiffusion1_4,
158    StableDiffusion1_5,
159    /// <https://huggingface.co/stabilityai/stable-diffusion-2-1> model.
160    ///  Vae-tiling enabled. 768x768.
161    StableDiffusion2_1,
162    /// Requires access rights to <https://huggingface.co/stabilityai/stable-diffusion-3-medium> providing a token via [crate::util::set_hf_token]
163    /// Vae-tiling enabled. 1024x1024. Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. 30 steps.
164    StableDiffusion3MediumFp16,
165    /// Requires access rights to <https://huggingface.co/stabilityai/stable-diffusion-3.5-medium> providing a token via [crate::util::set_hf_token]
166    /// Vae-tiling enabled. 1024x1024. Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. cfg_scale 4.5. 40 steps.
167    StableDiffusion3_5MediumFp16,
168    /// Requires access rights to <https://huggingface.co/stabilityai/stable-diffusion-3.5-large> providing a token via [crate::util::set_hf_token]
169    /// Vae-tiling enabled. 1024x1024. Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. cfg_scale 4.5. 28 steps.
170    StableDiffusion3_5LargeFp16,
171    /// Requires access rights to <https://huggingface.co/stabilityai/stable-diffusion-3.5-large-turbo> providing a token via [crate::util::set_hf_token]
172    /// Vae-tiling enabled. 1024x1024. Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. cfg_scale 0. 4 steps.
173    StableDiffusion3_5LargeTurboFp16,
174    SDXLBase1_0,
175    /// cfg_scale 1. guidance 0. 4 steps
176    SDTurbo,
177    /// cfg_scale 1. guidance 0. 4 steps
178    SDXLTurbo1_0Fp16,
179    /// Requires access rights to <https://huggingface.co/black-forest-labs/FLUX.1-dev> providing a token via [crate::util::set_hf_token]
180    /// Vae-tiling enabled. 1024x1024. Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. 28 steps.
181    Flux1Dev(Flux1Weight),
182    /// Requires access rights to <https://huggingface.co/black-forest-labs/FLUX.1-schnell> providing a token via [crate::util::set_hf_token]
183    /// Vae-tiling enabled. 1024x1024. Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. 4 steps.
184    Flux1Schnell(Flux1Weight),
185    /// A 3.2B param rectified flow transformer distilled from FLUX.1-dev <https://huggingface.co/TencentARC/flux-mini> <https://huggingface.co/HyperX-Sentience/Flux-Mini-GGUF>
186    /// Vae-tiling enabled. 512x512. Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. cfg_scale 1. 20 steps.
187    Flux1Mini(Flux1MiniWeight),
188    /// Requires access rights to <https://huggingface.co/RunDiffusion/Juggernaut-XI-v11> providing a token via [crate::util::set_hf_token]
189    /// Vae-tiling enabled. 1024x1024. Enabled [crate::api::SampleMethod::DPM2_SAMPLE_METHOD]. guidance 6. 20 steps
190    JuggernautXL11,
191    /// Chroma is an 8.9B parameter model based on FLUX.1-schnell
192    /// Requires access rights to <https://huggingface.co/black-forest-labs/FLUX.1-dev> providing a token via [crate::util::set_hf_token]
193    /// Vae-tiling enabled. 512x512. Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. cfg_scale 4. 20 steps
194    Chroma(ChromaWeight),
195    /// sgm_uniform scheduler. cfg_scale 1. timestep_shift 250. 1 steps. 1024x1024
196    NitroSDRealism(NitroSDRealismWeight),
197    /// sgm_uniform scheduler. cfg_scale 1. timestep_shift 500. 1 steps. 1024x1024
198    NitroSDVibrant(NitroSDVibrantWeight),
199    /// sgm_uniform scheduler. cfg_scale 1. timestep_shift 400. 1 steps. 1024x1024
200    DiffInstructStar(DiffInstructStarWeight),
201    /// Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. cfg_scale 4.0. 20 steps.
202    ChromaRadiance(ChromaRadianceWeight),
203    /// cfg_scale 9.0. 20 steps. 1024x1024
204    SSD1B(SSD1BWeight),
205    /// Requires access rights to <https://huggingface.co/black-forest-labs/FLUX.2-dev> providing a token via [crate::util::set_hf_token]
206    /// Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. cfg_scale 1.0. Flash attention enabled. Offload params to CPU enabled. 20 steps. 512x512. Vae-tiling enabled.
207    Flux2Dev(Flux2Weight),
208    /// Requires access rights to <https://huggingface.co/black-forest-labs/FLUX.1-schnell> providing a token via [crate::util::set_hf_token]
209    /// cfg_scale 1.0. 9 steps. Flash attention enabled. 1024x1024. Vae-tiling enabled.
210    ZImageTurbo(ZImageTurboWeight),
211    /// Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. cfg_scale 2.5. flow_shift 3.0. Flash attention enabled. Offload params to CPU enabled. 20 steps. 1024x1024. Vae-tiling enabled.
212    QwenImage(QwenImageWeight),
213    /// Requires access rights to <https://huggingface.co/black-forest-labs/FLUX.1-schnel> providing a token via [crate::util::set_hf_token]
214    /// cfg_scale 5.0. Flash attention enabled. Offload params to CPU enabled. 20 steps. Vae-tiling enabled. 512x512.
215    OvisImage(OvisImageWeight),
216}
217
218impl Preset {
219    fn try_configs_builder(self) -> Result<(ConfigBuilder, ModelConfigBuilder), ApiError> {
220        match self {
221            Preset::StableDiffusion1_4 => stable_diffusion_1_4(),
222            Preset::StableDiffusion1_5 => stable_diffusion_1_5(),
223            Preset::StableDiffusion2_1 => stable_diffusion_2_1(),
224            Preset::StableDiffusion3MediumFp16 => stable_diffusion_3_medium_fp16(),
225            Preset::SDXLBase1_0 => sdxl_base_1_0(),
226            Preset::Flux1Dev(sd_type_t) => flux_1_dev(sd_type_t),
227            Preset::Flux1Schnell(sd_type_t) => flux_1_schnell(sd_type_t),
228            Preset::SDTurbo => sd_turbo(),
229            Preset::SDXLTurbo1_0Fp16 => sdxl_turbo_1_0_fp16(),
230            Preset::StableDiffusion3_5LargeFp16 => stable_diffusion_3_5_large_fp16(),
231            Preset::StableDiffusion3_5MediumFp16 => stable_diffusion_3_5_medium_fp16(),
232            Preset::StableDiffusion3_5LargeTurboFp16 => stable_diffusion_3_5_large_turbo_fp16(),
233            Preset::JuggernautXL11 => juggernaut_xl_11(),
234            Preset::Flux1Mini(sd_type_t) => flux_1_mini(sd_type_t),
235            Preset::Chroma(sd_type_t) => chroma(sd_type_t),
236            Preset::NitroSDRealism(sd_type_t) => nitro_sd_realism(sd_type_t),
237            Preset::NitroSDVibrant(sd_type_t) => nitro_sd_vibrant(sd_type_t),
238            Preset::DiffInstructStar(sd_type_t) => diff_instruct_star(sd_type_t),
239            Preset::ChromaRadiance(sd_type_t) => chroma_radiance(sd_type_t),
240            Preset::SSD1B(sd_type_t) => ssd_1b(sd_type_t),
241            Preset::Flux2Dev(sd_type_t) => flux_2_dev(sd_type_t),
242            Preset::ZImageTurbo(sd_type_t) => z_image_turbo(sd_type_t),
243            Preset::QwenImage(sd_type_t) => qwen_image(sd_type_t),
244            Preset::OvisImage(sd_type_t) => ovis_image(sd_type_t),
245        }
246    }
247}
248
/// Pair of builders passed through the functions in [crate::modifier].
pub type ConfigsBuilder = (ConfigBuilder, ModelConfigBuilder);

/// Pair of finished configs returned by [PresetBuilder::build].
pub type Configs = (Config, ModelConfig);

/// Function pointer that takes a [ConfigsBuilder] and returns the modified pair, or an
/// [ApiError] on failure. See [crate::modifier].
pub type Modifier = fn(ConfigsBuilder) -> Result<ConfigsBuilder, ApiError>;
257
/// Helper struct for [ConfigBuilder]
///
/// Values are collected through the generated `PresetBuilder`; the derived build function is
/// renamed to `internal_build` and made private, and reports `ConfigBuilderError`.
#[derive(Debug, Clone, Builder)]
#[builder(
    name = "PresetBuilder",
    setter(into),
    build_fn(name = "internal_build", private, error = "ConfigBuilderError")
)]
pub struct PresetConfig {
    /// Prompt written to the [ConfigBuilder] after every modifier has run.
    prompt: String,
    /// Preset that supplies the initial pair of builders.
    preset: Preset,
    // The generated setter is private and the field defaults to an empty list.
    #[builder(private, default = "Vec::new()")]
    modifiers: Vec<Modifier>,
}
271
272impl PresetBuilder {
273    /// Add modifier that will apply in sequence
274    pub fn with_modifier(&mut self, f: Modifier) -> &mut Self {
275        if self.modifiers.is_none() {
276            self.modifiers = Some(Vec::new());
277        }
278        self.modifiers.as_mut().unwrap().push(f);
279        self
280    }
281
282    pub fn build(&mut self) -> Result<Configs, ConfigBuilderError> {
283        let preset = self.internal_build()?;
284        let configs: ConfigsBuilder = preset
285            .try_into()
286            .map_err(|err: ApiError| ConfigBuilderError::ValidationError(err.to_string()))?;
287        let config = configs.0.build()?;
288        let config_model = configs.1.build()?;
289
290        Ok((config, config_model))
291    }
292}
293
294impl TryFrom<PresetConfig> for ConfigsBuilder {
295    type Error = ApiError;
296
297    fn try_from(value: PresetConfig) -> Result<Self, Self::Error> {
298        let mut configs_builder = value.preset.try_configs_builder()?;
299        for m in value.modifiers {
300            configs_builder = m(configs_builder)?;
301        }
302        configs_builder.0.prompt(value.prompt);
303        Ok(configs_builder)
304    }
305}
306
#[cfg(test)]
mod tests {
    use crate::{
        api::gen_img,
        preset::{
            ChromaRadianceWeight, ChromaWeight, DiffInstructStarWeight, Flux1MiniWeight,
            Flux1Weight, Flux2Weight, NitroSDRealismWeight, NitroSDVibrantWeight, OvisImageWeight,
            QwenImageWeight, SSD1BWeight, ZImageTurboWeight,
        },
        util::set_hf_token,
    };

    use super::{Preset, PresetBuilder};
    static PROMPT: &str = "a lovely dynosaur made by crochet";

    /// Builds the configs for `preset` with the shared prompt and runs `gen_img` on them,
    /// panicking if either step fails.
    fn run(preset: Preset) {
        let (config, mut model_config) = PresetBuilder::default()
            .preset(preset)
            .prompt(PROMPT)
            .build()
            .unwrap();
        gen_img(&config, &mut model_config).unwrap();
    }

    // All tests below are `#[ignore]`d and only run when requested explicitly. Those that call
    // `set_hf_token` embed `../token.txt` at compile time.

    #[ignore]
    #[test]
    fn test_stable_diffusion_1_4() {
        run(Preset::StableDiffusion1_4);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_1_5() {
        run(Preset::StableDiffusion1_5);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_2_1() {
        run(Preset::StableDiffusion2_1);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_medium_fp16() {
        set_hf_token(include_str!("../token.txt"));
        run(Preset::StableDiffusion3MediumFp16);
    }

    #[ignore]
    #[test]
    fn test_sdxl_base_1_0() {
        run(Preset::SDXLBase1_0);
    }

    #[ignore]
    #[test]
    fn test_flux_1_dev() {
        set_hf_token(include_str!("../token.txt"));
        run(Preset::Flux1Dev(Flux1Weight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_flux_1_schnell() {
        set_hf_token(include_str!("../token.txt"));
        run(Preset::Flux1Schnell(Flux1Weight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_sd_turbo() {
        run(Preset::SDTurbo);
    }

    #[ignore]
    #[test]
    fn test_sdxl_turbo_1_0_fp16() {
        run(Preset::SDXLTurbo1_0Fp16);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_5_medium_fp16() {
        set_hf_token(include_str!("../token.txt"));
        run(Preset::StableDiffusion3_5MediumFp16);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_5_large_fp16() {
        set_hf_token(include_str!("../token.txt"));
        run(Preset::StableDiffusion3_5LargeFp16);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_5_large_turbo_fp16() {
        set_hf_token(include_str!("../token.txt"));
        run(Preset::StableDiffusion3_5LargeTurboFp16);
    }

    #[ignore]
    #[test]
    fn test_juggernaut_xl_11() {
        set_hf_token(include_str!("../token.txt"));
        run(Preset::JuggernautXL11);
    }

    #[ignore]
    #[test]
    fn test_flux_1_mini() {
        set_hf_token(include_str!("../token.txt"));
        run(Preset::Flux1Mini(Flux1MiniWeight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_chroma() {
        set_hf_token(include_str!("../token.txt"));
        run(Preset::Chroma(ChromaWeight::Q4_0));
    }

    #[ignore]
    #[test]
    fn test_nitro_sd_realism() {
        run(Preset::NitroSDRealism(NitroSDRealismWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_nitro_sd_vibrant() {
        run(Preset::NitroSDVibrant(NitroSDVibrantWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_diff_instruct_star() {
        run(Preset::DiffInstructStar(DiffInstructStarWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_chroma_radiance() {
        run(Preset::ChromaRadiance(ChromaRadianceWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_ssd_1b() {
        run(Preset::SSD1B(SSD1BWeight::F8_E4M3));
    }

    #[ignore]
    #[test]
    fn test_flux_2_dev() {
        set_hf_token(include_str!("../token.txt"));
        run(Preset::Flux2Dev(Flux2Weight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_z_image_turbo() {
        set_hf_token(include_str!("../token.txt"));
        run(Preset::ZImageTurbo(ZImageTurboWeight::Q2_K));
    }

    // Named with the `test_` prefix like every other test in this module.
    #[ignore]
    #[test]
    fn test_qwen_image() {
        run(Preset::QwenImage(QwenImageWeight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_ovis_image() {
        set_hf_token(include_str!("../token.txt"));
        run(Preset::OvisImage(OvisImageWeight::Q4_0));
    }
}