diffusion_rs/
preset.rs

1use derive_builder::Builder;
2use hf_hub::api::sync::ApiError;
3use strum::{EnumDiscriminants, EnumString, VariantNames};
4use subenum::subenum;
5
6use crate::{
7    api::{Config, ConfigBuilder, ConfigBuilderError, ModelConfig, ModelConfigBuilder},
8    preset_builder::{
9        chroma, chroma_radiance, diff_instruct_star, flux_1_dev, flux_1_mini, flux_1_schnell,
10        flux_2_dev, juggernaut_xl_11, nitro_sd_realism, nitro_sd_vibrant, ovis_image, qwen_image,
11        sd_turbo, sdxl_base_1_0, sdxl_turbo_1_0_fp16, ssd_1b, stable_diffusion_1_4,
12        stable_diffusion_1_5, stable_diffusion_2_1, stable_diffusion_3_5_large_fp16,
13        stable_diffusion_3_5_large_turbo_fp16, stable_diffusion_3_5_medium_fp16,
14        stable_diffusion_3_medium_fp16, z_image_turbo,
15    },
16};
17
18#[non_exhaustive]
19#[allow(non_camel_case_types)]
20#[subenum(
21    Flux1Weight(derive(Default)),
22    Flux1MiniWeight(derive(Default)),
23    ChromaWeight(derive(Default)),
24    NitroSDRealismWeight(derive(Default)),
25    NitroSDVibrantWeight(derive(Default)),
26    DiffInstructStarWeight(derive(Default)),
27    ChromaRadianceWeight(derive(Default)),
28    SSD1BWeight(derive(Default)),
29    Flux2Weight(derive(Default)),
30    ZImageTurboWeight(derive(Default)),
31    QwenImageWeight(derive(Default)),
32    OvisImageWeight(derive(Default))
33)]
34#[derive(Debug, Clone, Copy, EnumString, VariantNames)]
35#[strum(ascii_case_insensitive)]
36/// Model weight types
37pub enum WeightType {
38    #[subenum(Flux1MiniWeight)]
39    F32,
40    #[subenum(
41        NitroSDRealismWeight,
42        NitroSDVibrantWeight,
43        DiffInstructStarWeight,
44        SSD1BWeight
45    )]
46    F16,
47    #[subenum(
48        Flux1Weight,
49        ChromaWeight(default),
50        NitroSDRealismWeight,
51        NitroSDVibrantWeight,
52        DiffInstructStarWeight,
53        Flux2Weight,
54        ZImageTurboWeight,
55        QwenImageWeight,
56        OvisImageWeight(default)
57    )]
58    Q4_0,
59    #[subenum(Flux2Weight, QwenImageWeight)]
60    Q4_1,
61    #[subenum(
62        NitroSDRealismWeight,
63        NitroSDVibrantWeight,
64        DiffInstructStarWeight,
65        Flux2Weight,
66        ZImageTurboWeight,
67        QwenImageWeight
68    )]
69    Q5_0,
70    #[subenum(Flux2Weight, QwenImageWeight)]
71    Q5_1,
72    #[subenum(
73        Flux1Weight,
74        Flux1MiniWeight(default),
75        ChromaWeight,
76        NitroSDRealismWeight(default),
77        NitroSDVibrantWeight(default),
78        DiffInstructStarWeight(default),
79        ChromaRadianceWeight(default),
80        Flux2Weight,
81        ZImageTurboWeight,
82        QwenImageWeight,
83        OvisImageWeight
84    )]
85    Q8_0,
86    Q8_1,
87    #[subenum(
88        Flux1Weight(default),
89        Flux1MiniWeight,
90        NitroSDRealismWeight,
91        NitroSDVibrantWeight,
92        DiffInstructStarWeight,
93        Flux2Weight(default),
94        ZImageTurboWeight,
95        QwenImageWeight(default)
96    )]
97    Q2_K,
98    #[subenum(
99        Flux1Weight,
100        Flux1MiniWeight,
101        NitroSDRealismWeight,
102        NitroSDVibrantWeight,
103        DiffInstructStarWeight,
104        ZImageTurboWeight,
105        Flux2Weight,
106        QwenImageWeight
107    )]
108    Q3_K,
109    #[subenum(Flux1Weight, ZImageTurboWeight(default), Flux2Weight, QwenImageWeight)]
110    Q4_K,
111    #[subenum(Flux1MiniWeight, Flux2Weight, QwenImageWeight)]
112    Q5_K,
113    #[subenum(
114        Flux1MiniWeight,
115        NitroSDRealismWeight,
116        NitroSDVibrantWeight,
117        DiffInstructStarWeight,
118        Flux2Weight,
119        ZImageTurboWeight,
120        QwenImageWeight
121    )]
122    Q6_K,
123    Q8_K,
124    IQ2_XXS,
125    IQ2_XS,
126    IQ3_XXS,
127    IQ1_S,
128    IQ4_NL,
129    IQ3_S,
130    IQ2_S,
131    IQ4_XS,
132    I8,
133    I16,
134    I32,
135    I64,
136    F64,
137    IQ1_M,
138    #[subenum(
139        Flux1MiniWeight,
140        ChromaWeight,
141        ChromaRadianceWeight,
142        Flux2Weight,
143        ZImageTurboWeight,
144        QwenImageWeight,
145        OvisImageWeight
146    )]
147    BF16,
148    TQ1_0,
149    TQ2_0,
150    MXFP4,
151    #[subenum(SSD1BWeight(default), QwenImageWeight)]
152    F8_E4M3,
153}
154
155#[non_exhaustive]
156#[derive(Debug, Clone, Copy, EnumDiscriminants)]
157#[strum_discriminants(derive(EnumString, VariantNames), strum(ascii_case_insensitive))]
158/// Models ready to use
159pub enum Preset {
160    StableDiffusion1_4,
161    StableDiffusion1_5,
162    /// <https://huggingface.co/stabilityai/stable-diffusion-2-1> model.
163    ///  Vae-tiling enabled. 768x768.
164    StableDiffusion2_1,
165    /// Requires access rights to <https://huggingface.co/stabilityai/stable-diffusion-3-medium> providing a token via [crate::util::set_hf_token]
166    /// Vae-tiling enabled. 1024x1024. Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. 30 steps.
167    StableDiffusion3MediumFp16,
168    /// Requires access rights to <https://huggingface.co/stabilityai/stable-diffusion-3.5-medium> providing a token via [crate::util::set_hf_token]
169    /// Vae-tiling enabled. 1024x1024. Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. cfg_scale 4.5. 40 steps.
170    StableDiffusion3_5MediumFp16,
171    /// Requires access rights to <https://huggingface.co/stabilityai/stable-diffusion-3.5-large> providing a token via [crate::util::set_hf_token]
172    /// Vae-tiling enabled. 1024x1024. Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. cfg_scale 4.5. 28 steps.
173    StableDiffusion3_5LargeFp16,
174    /// Requires access rights to <https://huggingface.co/stabilityai/stable-diffusion-3.5-large-turbo> providing a token via [crate::util::set_hf_token]
175    /// Vae-tiling enabled. 1024x1024. Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. cfg_scale 0. 4 steps.
176    StableDiffusion3_5LargeTurboFp16,
177    SDXLBase1_0,
178    /// cfg_scale 1. guidance 0. 4 steps
179    SDTurbo,
180    /// cfg_scale 1. guidance 0. 4 steps
181    SDXLTurbo1_0Fp16,
182    /// Requires access rights to <https://huggingface.co/black-forest-labs/FLUX.1-dev> providing a token via [crate::util::set_hf_token]
183    /// Vae-tiling enabled. 1024x1024. Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. 28 steps.
184    Flux1Dev(Flux1Weight),
185    /// Requires access rights to <https://huggingface.co/black-forest-labs/FLUX.1-schnell> providing a token via [crate::util::set_hf_token]
186    /// Vae-tiling enabled. 1024x1024. Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. 4 steps.
187    Flux1Schnell(Flux1Weight),
188    /// A 3.2B param rectified flow transformer distilled from FLUX.1-dev <https://huggingface.co/TencentARC/flux-mini> <https://huggingface.co/HyperX-Sentience/Flux-Mini-GGUF>
189    /// Vae-tiling enabled. 512x512. Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. cfg_scale 1. 20 steps.
190    Flux1Mini(Flux1MiniWeight),
191    /// Requires access rights to <https://huggingface.co/RunDiffusion/Juggernaut-XI-v11> providing a token via [crate::util::set_hf_token]
192    /// Vae-tiling enabled. 1024x1024. Enabled [crate::api::SampleMethod::DPM2_SAMPLE_METHOD]. guidance 6. 20 steps
193    JuggernautXL11,
194    /// Chroma is an 8.9B parameter model based on FLUX.1-schnell
195    /// Requires access rights to <https://huggingface.co/black-forest-labs/FLUX.1-dev> providing a token via [crate::util::set_hf_token]
196    /// Vae-tiling enabled. 512x512. Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. cfg_scale 4. 20 steps
197    Chroma(ChromaWeight),
198    /// sgm_uniform scheduler. cfg_scale 1. timestep_shift 250. 1 steps. 1024x1024
199    NitroSDRealism(NitroSDRealismWeight),
200    /// sgm_uniform scheduler. cfg_scale 1. timestep_shift 500. 1 steps. 1024x1024
201    NitroSDVibrant(NitroSDVibrantWeight),
202    /// sgm_uniform scheduler. cfg_scale 1. timestep_shift 400. 1 steps. 1024x1024
203    DiffInstructStar(DiffInstructStarWeight),
204    /// Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. cfg_scale 4.0. 20 steps.
205    ChromaRadiance(ChromaRadianceWeight),
206    /// cfg_scale 9.0. 20 steps. 1024x1024
207    SSD1B(SSD1BWeight),
208    /// Requires access rights to <https://huggingface.co/black-forest-labs/FLUX.2-dev> providing a token via [crate::util::set_hf_token]
209    /// Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. cfg_scale 1.0. Flash attention enabled. Offload params to CPU enabled. 20 steps. 512x512. Vae-tiling enabled.
210    Flux2Dev(Flux2Weight),
211    /// Requires access rights to <https://huggingface.co/black-forest-labs/FLUX.1-schnell> providing a token via [crate::util::set_hf_token]
212    /// cfg_scale 1.0. 9 steps. Flash attention enabled. 1024x1024. Vae-tiling enabled.
213    ZImageTurbo(ZImageTurboWeight),
214    /// Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. cfg_scale 2.5. flow_shift 3.0. Flash attention enabled. Offload params to CPU enabled. 20 steps. 1024x1024. Vae-tiling enabled.
215    QwenImage(QwenImageWeight),
216    /// Requires access rights to <https://huggingface.co/black-forest-labs/FLUX.1-schnel> providing a token via [crate::util::set_hf_token]
217    /// cfg_scale 5.0. Flash attention enabled. Offload params to CPU enabled. 20 steps. Vae-tiling enabled. 512x512.
218    OvisImage(OvisImageWeight),
219}
220
221impl Preset {
222    fn try_configs_builder(self) -> Result<(ConfigBuilder, ModelConfigBuilder), ApiError> {
223        #[allow(unused_mut)]
224        let mut preset = match self {
225            Preset::StableDiffusion1_4 => stable_diffusion_1_4(),
226            Preset::StableDiffusion1_5 => stable_diffusion_1_5(),
227            Preset::StableDiffusion2_1 => stable_diffusion_2_1(),
228            Preset::StableDiffusion3MediumFp16 => stable_diffusion_3_medium_fp16(),
229            Preset::SDXLBase1_0 => sdxl_base_1_0(),
230            Preset::Flux1Dev(sd_type_t) => flux_1_dev(sd_type_t),
231            Preset::Flux1Schnell(sd_type_t) => flux_1_schnell(sd_type_t),
232            Preset::SDTurbo => sd_turbo(),
233            Preset::SDXLTurbo1_0Fp16 => sdxl_turbo_1_0_fp16(),
234            Preset::StableDiffusion3_5LargeFp16 => stable_diffusion_3_5_large_fp16(),
235            Preset::StableDiffusion3_5MediumFp16 => stable_diffusion_3_5_medium_fp16(),
236            Preset::StableDiffusion3_5LargeTurboFp16 => stable_diffusion_3_5_large_turbo_fp16(),
237            Preset::JuggernautXL11 => juggernaut_xl_11(),
238            Preset::Flux1Mini(sd_type_t) => flux_1_mini(sd_type_t),
239            Preset::Chroma(sd_type_t) => chroma(sd_type_t),
240            Preset::NitroSDRealism(sd_type_t) => nitro_sd_realism(sd_type_t),
241            Preset::NitroSDVibrant(sd_type_t) => nitro_sd_vibrant(sd_type_t),
242            Preset::DiffInstructStar(sd_type_t) => diff_instruct_star(sd_type_t),
243            Preset::ChromaRadiance(sd_type_t) => chroma_radiance(sd_type_t),
244            Preset::SSD1B(sd_type_t) => ssd_1b(sd_type_t),
245            Preset::Flux2Dev(sd_type_t) => flux_2_dev(sd_type_t),
246            Preset::ZImageTurbo(sd_type_t) => z_image_turbo(sd_type_t),
247            Preset::QwenImage(sd_type_t) => qwen_image(sd_type_t),
248            Preset::OvisImage(sd_type_t) => ovis_image(sd_type_t),
249        };
250
251        // Metal workaround.
252        // See https://github.com/leejet/stable-diffusion.cpp/issues/1040#issuecomment-3623644576
253        #[cfg(feature = "metal")]
254        {
255            if let Ok((_, model_config)) = &mut preset {
256                model_config.clip_on_cpu(true);
257            };
258        }
259        preset
260    }
261}
262
263/// Configs tuple used by [crate::modifier]
264pub type ConfigsBuilder = (ConfigBuilder, ModelConfigBuilder);
265
266/// Returned by [PresetBuilder::build]
267pub type Configs = (Config, ModelConfig);
268
269/// Helper functions that modifies the [ConfigBuilder] See [crate::modifier]
270type Modifier = dyn FnOnce(ConfigsBuilder) -> Result<ConfigsBuilder, ApiError>;
271
272#[derive(Builder)]
273#[builder(
274    name = "PresetBuilder",
275    pattern = "owned",
276    setter(into),
277    build_fn(name = "internal_build", private, error = "ConfigBuilderError")
278)]
279/// Helper struct for [ConfigBuilder]
280pub struct PresetConfig {
281    prompt: String,
282    preset: Preset,
283    #[builder(private, default = "Vec::new()")]
284    modifiers: Vec<Box<Modifier>>,
285}
286
287impl PresetBuilder {
288    /// Add modifier that will apply in sequence
289    pub fn with_modifier<F>(mut self, f: F) -> Self
290    where
291        F: FnOnce(ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> + 'static,
292    {
293        if self.modifiers.is_none() {
294            self.modifiers = Some(Vec::new());
295        }
296        self.modifiers.as_mut().unwrap().push(Box::new(f));
297        self
298    }
299
300    pub fn build(self) -> Result<Configs, ConfigBuilderError> {
301        let preset = self.internal_build()?;
302        let configs: ConfigsBuilder = preset
303            .try_into()
304            .map_err(|err: ApiError| ConfigBuilderError::ValidationError(err.to_string()))?;
305        let config = configs.0.build()?;
306        let config_model = configs.1.build()?;
307
308        Ok((config, config_model))
309    }
310}
311
312impl TryFrom<PresetConfig> for ConfigsBuilder {
313    type Error = ApiError;
314
315    fn try_from(value: PresetConfig) -> Result<Self, Self::Error> {
316        let mut configs_builder = value.preset.try_configs_builder()?;
317        for m in value.modifiers {
318            configs_builder = m(configs_builder)?;
319        }
320        configs_builder.0.prompt(value.prompt);
321        Ok(configs_builder)
322    }
323}
324
#[cfg(test)]
mod tests {
    use super::{
        ChromaRadianceWeight, ChromaWeight, DiffInstructStarWeight, Flux1MiniWeight, Flux1Weight,
        Flux2Weight, NitroSDRealismWeight, NitroSDVibrantWeight, OvisImageWeight, Preset,
        PresetBuilder, QwenImageWeight, SSD1BWeight, ZImageTurboWeight,
    };
    use crate::{api::gen_img, util::set_hf_token};

    static PROMPT: &str = "a lovely dinosaur made by crochet";

    /// Builds the configs for `preset` with [PROMPT] and generates an image, panicking on
    /// any failure.
    fn run(preset: Preset) {
        let (config, mut model_config) = PresetBuilder::default()
            .preset(preset)
            .prompt(PROMPT)
            .build()
            .unwrap();
        gen_img(&config, &mut model_config).unwrap();
    }

    /// Declares one ignored test that runs a preset end to end.
    ///
    /// The `token` form first installs the Hugging Face token embedded at compile time from
    /// `../token.txt`.
    macro_rules! preset_test {
        (token, $name:ident, $preset:expr) => {
            #[ignore]
            #[test]
            fn $name() {
                set_hf_token(include_str!("../token.txt"));
                run($preset);
            }
        };
        ($name:ident, $preset:expr) => {
            #[ignore]
            #[test]
            fn $name() {
                run($preset);
            }
        };
    }

    preset_test!(test_stable_diffusion_1_4, Preset::StableDiffusion1_4);
    preset_test!(test_stable_diffusion_1_5, Preset::StableDiffusion1_5);
    preset_test!(test_stable_diffusion_2_1, Preset::StableDiffusion2_1);
    preset_test!(
        token,
        test_stable_diffusion_3_medium_fp16,
        Preset::StableDiffusion3MediumFp16
    );
    preset_test!(test_sdxl_base_1_0, Preset::SDXLBase1_0);
    preset_test!(token, test_flux_1_dev, Preset::Flux1Dev(Flux1Weight::Q2_K));
    preset_test!(
        token,
        test_flux_1_schnell,
        Preset::Flux1Schnell(Flux1Weight::Q2_K)
    );
    preset_test!(test_sd_turbo, Preset::SDTurbo);
    preset_test!(test_sdxl_turbo_1_0_fp16, Preset::SDXLTurbo1_0Fp16);
    preset_test!(
        token,
        test_stable_diffusion_3_5_medium_fp16,
        Preset::StableDiffusion3_5MediumFp16
    );
    preset_test!(
        token,
        test_stable_diffusion_3_5_large_fp16,
        Preset::StableDiffusion3_5LargeFp16
    );
    preset_test!(
        token,
        test_stable_diffusion_3_5_large_turbo_fp16,
        Preset::StableDiffusion3_5LargeTurboFp16
    );
    preset_test!(token, test_juggernaut_xl_11, Preset::JuggernautXL11);
    preset_test!(
        token,
        test_flux_1_mini,
        Preset::Flux1Mini(Flux1MiniWeight::Q2_K)
    );
    preset_test!(token, test_chroma, Preset::Chroma(ChromaWeight::Q4_0));
    preset_test!(
        test_nitro_sd_realism,
        Preset::NitroSDRealism(NitroSDRealismWeight::Q8_0)
    );
    preset_test!(
        test_nitro_sd_vibrant,
        Preset::NitroSDVibrant(NitroSDVibrantWeight::Q8_0)
    );
    preset_test!(
        test_diff_instruct_star,
        Preset::DiffInstructStar(DiffInstructStarWeight::Q8_0)
    );
    preset_test!(
        test_chroma_radiance,
        Preset::ChromaRadiance(ChromaRadianceWeight::Q8_0)
    );
    preset_test!(test_ssd_1b, Preset::SSD1B(SSD1BWeight::F8_E4M3));
    preset_test!(token, test_flux_2_dev, Preset::Flux2Dev(Flux2Weight::Q2_K));
    preset_test!(
        token,
        test_z_image_turbo,
        Preset::ZImageTurbo(ZImageTurboWeight::Q2_K)
    );
    preset_test!(test_qwen_image, Preset::QwenImage(QwenImageWeight::Q2_K));
    preset_test!(
        token,
        test_ovis_image,
        Preset::OvisImage(OvisImageWeight::Q4_0)
    );
}