diffusion_rs/
preset.rs

1use derive_builder::Builder;
2use hf_hub::api::sync::ApiError;
3use strum::{EnumDiscriminants, EnumString, VariantNames};
4use subenum::subenum;
5
6use crate::{
7    api::{Config, ConfigBuilder, ConfigBuilderError, ModelConfig, ModelConfigBuilder},
8    preset_builder::{
9        chroma, chroma_radiance, diff_instruct_star, dream_shaper_xl_2_1_turbo, flux_1_dev,
10        flux_1_mini, flux_1_schnell, flux_2_dev, juggernaut_xl_11, nitro_sd_realism,
11        nitro_sd_vibrant, ovis_image, qwen_image, sd_turbo, sdxl_base_1_0, sdxl_turbo_1_0, ssd_1b,
12        stable_diffusion_1_4, stable_diffusion_1_5, stable_diffusion_2_1,
13        stable_diffusion_3_5_large, stable_diffusion_3_5_large_turbo, stable_diffusion_3_5_medium,
14        stable_diffusion_3_medium, twinflow_z_image_turbo, z_image_turbo,
15    },
16};
17
#[non_exhaustive]
// Variant names such as `Q4_0` or `IQ2_XXS` contain underscores, which the
// `non_camel_case_types` lint would otherwise warn about.
#[allow(non_camel_case_types)]
#[subenum(
    Flux1Weight(derive(Default)),
    Flux1MiniWeight(derive(Default)),
    ChromaWeight(derive(Default)),
    NitroSDRealismWeight(derive(Default)),
    NitroSDVibrantWeight(derive(Default)),
    DiffInstructStarWeight(derive(Default)),
    ChromaRadianceWeight(derive(Default)),
    SSD1BWeight(derive(Default)),
    Flux2Weight(derive(Default)),
    ZImageTurboWeight(derive(Default)),
    QwenImageWeight(derive(Default)),
    OvisImageWeight(derive(Default)),
    TwinFlowZImageTurboExpWeight(derive(Default))
)]
#[derive(Debug, Clone, Copy, EnumString, VariantNames)]
#[strum(ascii_case_insensitive)]
/// Model weight types
///
/// Values can be parsed from strings case-insensitively (`EnumString` with
/// `ascii_case_insensitive`); `VariantNames` exposes the list of variant names.
///
/// Each `#[subenum(...)]` attribute on a variant adds that variant to the listed
/// per-model weight enums (`Flux1Weight`, `QwenImageWeight`, ...). Those subenums
/// are the payload types of the matching [Preset] variants, so a preset can only
/// be built with the weights tagged for it. The variant marked `(default)` in a
/// subenum list becomes that subenum's `Default` value.
// NOTE(review): variants have implicit discriminants in declaration order; if the
// enum is converted to a numeric type elsewhere, reordering would change values —
// confirm before inserting or moving variants.
pub enum WeightType {
    #[subenum(Flux1MiniWeight)]
    F32,
    #[subenum(
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        SSD1BWeight
    )]
    F16,
    #[subenum(
        Flux1Weight,
        ChromaWeight(default),
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        OvisImageWeight(default),
        TwinFlowZImageTurboExpWeight(default)
    )]
    Q4_0,
    #[subenum(Flux2Weight, QwenImageWeight)]
    Q4_1,
    #[subenum(
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        TwinFlowZImageTurboExpWeight
    )]
    Q5_0,
    #[subenum(Flux2Weight, QwenImageWeight)]
    Q5_1,
    #[subenum(
        Flux1Weight,
        Flux1MiniWeight(default),
        ChromaWeight,
        NitroSDRealismWeight(default),
        NitroSDVibrantWeight(default),
        DiffInstructStarWeight(default),
        ChromaRadianceWeight(default),
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        OvisImageWeight,
        TwinFlowZImageTurboExpWeight
    )]
    Q8_0,
    // Not part of any per-model subenum.
    Q8_1,
    #[subenum(
        Flux1Weight(default),
        Flux1MiniWeight,
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        Flux2Weight(default),
        ZImageTurboWeight,
        QwenImageWeight(default)
    )]
    Q2_K,
    #[subenum(
        Flux1Weight,
        Flux1MiniWeight,
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        ZImageTurboWeight,
        Flux2Weight,
        QwenImageWeight,
        TwinFlowZImageTurboExpWeight
    )]
    Q3_K,
    #[subenum(Flux1Weight, ZImageTurboWeight(default), Flux2Weight, QwenImageWeight)]
    Q4_K,
    #[subenum(Flux1MiniWeight, Flux2Weight, QwenImageWeight)]
    Q5_K,
    #[subenum(
        Flux1MiniWeight,
        NitroSDRealismWeight,
        NitroSDVibrantWeight,
        DiffInstructStarWeight,
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        TwinFlowZImageTurboExpWeight
    )]
    Q6_K,
    // Q8_K through IQ1_M are not part of any per-model subenum.
    Q8_K,
    IQ2_XXS,
    IQ2_XS,
    IQ3_XXS,
    IQ1_S,
    IQ4_NL,
    IQ3_S,
    IQ2_S,
    IQ4_XS,
    I8,
    I16,
    I32,
    I64,
    F64,
    IQ1_M,
    #[subenum(
        Flux1MiniWeight,
        ChromaWeight,
        ChromaRadianceWeight,
        Flux2Weight,
        ZImageTurboWeight,
        QwenImageWeight,
        OvisImageWeight,
        TwinFlowZImageTurboExpWeight
    )]
    BF16,
    // TQ1_0, TQ2_0 and MXFP4 are not part of any per-model subenum.
    TQ1_0,
    TQ2_0,
    MXFP4,
    #[subenum(SSD1BWeight(default), QwenImageWeight)]
    F8_E4M3,
}
161
162#[non_exhaustive]
163#[derive(Debug, Clone, Copy, EnumDiscriminants)]
164#[strum_discriminants(derive(EnumString, VariantNames), strum(ascii_case_insensitive))]
165/// Models ready to use
166pub enum Preset {
167    StableDiffusion1_4,
168    StableDiffusion1_5,
169    /// <https://huggingface.co/stabilityai/stable-diffusion-2-1> model.
170    ///  Vae-tiling enabled. 768x768.
171    StableDiffusion2_1,
172    /// Requires access rights to <https://huggingface.co/stabilityai/stable-diffusion-3-medium> providing a token via [crate::util::set_hf_token]
173    /// Vae-tiling enabled. 1024x1024. Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. 30 steps.
174    StableDiffusion3Medium,
175    /// Requires access rights to <https://huggingface.co/stabilityai/stable-diffusion-3.5-medium> providing a token via [crate::util::set_hf_token]
176    /// Vae-tiling enabled. 1024x1024. Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. cfg_scale 4.5. 40 steps.
177    StableDiffusion3_5Medium,
178    /// Requires access rights to <https://huggingface.co/stabilityai/stable-diffusion-3.5-large> providing a token via [crate::util::set_hf_token]
179    /// Vae-tiling enabled. 1024x1024. Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. cfg_scale 4.5. 28 steps.
180    StableDiffusion3_5Large,
181    /// Requires access rights to <https://huggingface.co/stabilityai/stable-diffusion-3.5-large-turbo> providing a token via [crate::util::set_hf_token]
182    /// Vae-tiling enabled. 1024x1024. Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. cfg_scale 0. 4 steps.
183    StableDiffusion3_5LargeTurbo,
184    SDXLBase1_0,
185    /// cfg_scale 1. guidance 0. 4 steps
186    SDTurbo,
187    /// cfg_scale 1. guidance 0. 4 steps
188    SDXLTurbo1_0,
189    /// Requires access rights to <https://huggingface.co/black-forest-labs/FLUX.1-dev> providing a token via [crate::util::set_hf_token]
190    /// Vae-tiling enabled. 1024x1024. Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. 28 steps.
191    Flux1Dev(Flux1Weight),
192    /// Requires access rights to <https://huggingface.co/black-forest-labs/FLUX.1-schnell> providing a token via [crate::util::set_hf_token]
193    /// Vae-tiling enabled. 1024x1024. Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. 4 steps.
194    Flux1Schnell(Flux1Weight),
195    /// A 3.2B param rectified flow transformer distilled from FLUX.1-dev <https://huggingface.co/TencentARC/flux-mini> <https://huggingface.co/HyperX-Sentience/Flux-Mini-GGUF>
196    /// Vae-tiling enabled. 512x512. Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. cfg_scale 1. 20 steps.
197    Flux1Mini(Flux1MiniWeight),
198    /// Requires access rights to <https://huggingface.co/RunDiffusion/Juggernaut-XI-v11> providing a token via [crate::util::set_hf_token]
199    /// Vae-tiling enabled. 1024x1024. Enabled [crate::api::SampleMethod::DPM2_SAMPLE_METHOD]. guidance 6. 20 steps
200    JuggernautXL11,
201    /// Chroma is an 8.9B parameter model based on FLUX.1-schnell
202    /// Requires access rights to <https://huggingface.co/black-forest-labs/FLUX.1-dev> providing a token via [crate::util::set_hf_token]
203    /// Vae-tiling enabled. 512x512. Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. cfg_scale 4. 20 steps
204    Chroma(ChromaWeight),
205    /// sgm_uniform scheduler. cfg_scale 1. timestep_shift 250. 1 steps. 1024x1024
206    NitroSDRealism(NitroSDRealismWeight),
207    /// sgm_uniform scheduler. cfg_scale 1. timestep_shift 500. 1 steps. 1024x1024
208    NitroSDVibrant(NitroSDVibrantWeight),
209    /// sgm_uniform scheduler. cfg_scale 1. timestep_shift 400. 1 steps. 1024x1024
210    DiffInstructStar(DiffInstructStarWeight),
211    /// Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. cfg_scale 4.0. 20 steps.
212    ChromaRadiance(ChromaRadianceWeight),
213    /// cfg_scale 9.0. 20 steps. 1024x1024
214    SSD1B(SSD1BWeight),
215    /// Requires access rights to <https://huggingface.co/black-forest-labs/FLUX.2-dev> providing a token via [crate::util::set_hf_token]
216    /// Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. cfg_scale 1.0. Flash attention enabled. Offload params to CPU enabled. 20 steps. 512x512. Vae-tiling enabled.
217    Flux2Dev(Flux2Weight),
218    /// Requires access rights to <https://huggingface.co/black-forest-labs/FLUX.1-schnell> providing a token via [crate::util::set_hf_token]
219    /// cfg_scale 1.0. 9 steps. Flash attention enabled. 1024x412. Vae-tiling enabled.
220    ZImageTurbo(ZImageTurboWeight),
221    /// Enabled [crate::api::SampleMethod::EULER_SAMPLE_METHOD]. cfg_scale 2.5. flow_shift 3.0. Flash attention enabled. Offload params to CPU enabled. 20 steps. 1024x1024. Vae-tiling enabled.
222    QwenImage(QwenImageWeight),
223    /// Requires access rights to <https://huggingface.co/black-forest-labs/FLUX.1-schnel> providing a token via [crate::util::set_hf_token]
224    /// cfg_scale 5.0. Flash attention enabled. Offload params to CPU enabled. 20 steps. Vae-tiling enabled. 512x512.
225    OvisImage(OvisImageWeight),
226    /// lykon/dreamshaper-xl-v2-turbo is a Stable Diffusion model that has been fine-tuned on stabilityai/stable-diffusion-xl-base-1.0.
227    /// guidance_scale 2.0. 6 steps. 1024x1024. Vae-tiling enabled. Enabled [crate::api::SampleMethod::DPM2_SAMPLE_METHOD]
228    DreamShaperXL2_1Turbo,
229    /// Requires access rights to <https://huggingface.co/black-forest-labs/FLUX.1-schnell> providing a token via [crate::util::set_hf_token]
230    /// Enabled [crate::api::SampleMethod::DPM2_SAMPLE_METHOD] and [crate::api::Scheduler::SMOOTHSTEP_SCHEDULER]. cfg_scale 1.0. 3 steps. Flash attention enabled. 1024x512. Vae-tiling enabled.
231    TwinFlowZImageTurboExp(TwinFlowZImageTurboExpWeight),
232}
233
234impl Preset {
235    fn try_configs_builder(self) -> Result<(ConfigBuilder, ModelConfigBuilder), ApiError> {
236        #[allow(unused_mut)]
237        let mut preset = match self {
238            Preset::StableDiffusion1_4 => stable_diffusion_1_4(),
239            Preset::StableDiffusion1_5 => stable_diffusion_1_5(),
240            Preset::StableDiffusion2_1 => stable_diffusion_2_1(),
241            Preset::StableDiffusion3Medium => stable_diffusion_3_medium(),
242            Preset::SDXLBase1_0 => sdxl_base_1_0(),
243            Preset::Flux1Dev(sd_type_t) => flux_1_dev(sd_type_t),
244            Preset::Flux1Schnell(sd_type_t) => flux_1_schnell(sd_type_t),
245            Preset::SDTurbo => sd_turbo(),
246            Preset::SDXLTurbo1_0 => sdxl_turbo_1_0(),
247            Preset::StableDiffusion3_5Large => stable_diffusion_3_5_large(),
248            Preset::StableDiffusion3_5Medium => stable_diffusion_3_5_medium(),
249            Preset::StableDiffusion3_5LargeTurbo => stable_diffusion_3_5_large_turbo(),
250            Preset::JuggernautXL11 => juggernaut_xl_11(),
251            Preset::Flux1Mini(sd_type_t) => flux_1_mini(sd_type_t),
252            Preset::Chroma(sd_type_t) => chroma(sd_type_t),
253            Preset::NitroSDRealism(sd_type_t) => nitro_sd_realism(sd_type_t),
254            Preset::NitroSDVibrant(sd_type_t) => nitro_sd_vibrant(sd_type_t),
255            Preset::DiffInstructStar(sd_type_t) => diff_instruct_star(sd_type_t),
256            Preset::ChromaRadiance(sd_type_t) => chroma_radiance(sd_type_t),
257            Preset::SSD1B(sd_type_t) => ssd_1b(sd_type_t),
258            Preset::Flux2Dev(sd_type_t) => flux_2_dev(sd_type_t),
259            Preset::ZImageTurbo(sd_type_t) => z_image_turbo(sd_type_t),
260            Preset::QwenImage(sd_type_t) => qwen_image(sd_type_t),
261            Preset::OvisImage(sd_type_t) => ovis_image(sd_type_t),
262            Preset::DreamShaperXL2_1Turbo => dream_shaper_xl_2_1_turbo(),
263            Preset::TwinFlowZImageTurboExp(sd_type_t) => twinflow_z_image_turbo(sd_type_t),
264        };
265
266        // Metal workaround.
267        // See https://github.com/leejet/stable-diffusion.cpp/issues/1040#issuecomment-3623644576
268        #[cfg(feature = "metal")]
269        {
270            if let Ok((_, model_config)) = &mut preset {
271                model_config.clip_on_cpu(true);
272            };
273        }
274        preset
275    }
276}
277
278/// Configs tuple used by [crate::modifier]
279pub type ConfigsBuilder = (ConfigBuilder, ModelConfigBuilder);
280
281/// Returned by [PresetBuilder::build]
282pub type Configs = (Config, ModelConfig);
283
284/// Helper functions that modifies the [ConfigBuilder] See [crate::modifier]
285type ModifierFunction = dyn FnOnce(ConfigsBuilder) -> Result<ConfigsBuilder, ApiError>;
286
#[derive(Builder)]
// The generated `internal_build` is private; the public entry point is the
// hand-written `PresetBuilder::build`, which also turns the preset into configs.
#[builder(
    name = "PresetBuilder",
    pattern = "owned",
    setter(into),
    build_fn(name = "internal_build", private, error = "ConfigBuilderError")
)]
/// Helper struct for [ConfigBuilder]
///
/// Its generated builder, [PresetBuilder], collects a prompt, a [Preset] and an
/// optional list of modifiers; [PresetBuilder::build] turns them into [Configs].
pub struct PresetConfig {
    /// Prompt set on the [ConfigBuilder] after all modifiers have run.
    prompt: String,
    /// Preset providing the initial pair of builders.
    preset: Preset,
    /// Modifiers applied in insertion order; populated via [PresetBuilder::with_modifier].
    #[builder(private, default = "Vec::new()")]
    modifiers: Vec<Box<ModifierFunction>>,
}
301
302impl PresetBuilder {
303    /// Add modifier that will apply in sequence
304    pub fn with_modifier<F>(mut self, f: F) -> Self
305    where
306        F: FnOnce(ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> + 'static,
307    {
308        if self.modifiers.is_none() {
309            self.modifiers = Some(Vec::new());
310        }
311        self.modifiers.as_mut().unwrap().push(Box::new(f));
312        self
313    }
314
315    pub fn build(self) -> Result<Configs, ConfigBuilderError> {
316        let preset = self.internal_build()?;
317        let configs: ConfigsBuilder = preset
318            .try_into()
319            .map_err(|err: ApiError| ConfigBuilderError::ValidationError(err.to_string()))?;
320        let config = configs.0.build()?;
321        let config_model = configs.1.build()?;
322
323        Ok((config, config_model))
324    }
325}
326
327impl TryFrom<PresetConfig> for ConfigsBuilder {
328    type Error = ApiError;
329
330    fn try_from(value: PresetConfig) -> Result<Self, Self::Error> {
331        let mut configs_builder = value.preset.try_configs_builder()?;
332        for m in value.modifiers {
333            configs_builder = m(configs_builder)?;
334        }
335        configs_builder.0.prompt(value.prompt);
336        Ok(configs_builder)
337    }
338}
339
#[cfg(test)]
mod tests {
    use crate::{
        api::gen_img,
        preset::{
            ChromaRadianceWeight, ChromaWeight, DiffInstructStarWeight, Flux1MiniWeight,
            Flux1Weight, Flux2Weight, NitroSDRealismWeight, NitroSDVibrantWeight, OvisImageWeight,
            QwenImageWeight, SSD1BWeight, TwinFlowZImageTurboExpWeight, ZImageTurboWeight,
        },
        util::set_hf_token,
    };

    use super::{Preset, PresetBuilder};
    static PROMPT: &str = "a lovely dinosaur made by crochet";

    /// Register the Hugging Face token stored in `token.txt` at the crate root
    /// (`CARGO_MANIFEST_DIR`), needed by presets that require access rights.
    ///
    /// The file is read at run time instead of through `include_str!`: with
    /// `include_str!` a missing token file breaks compilation of the whole test
    /// target, even though only these `#[ignore]`d tests use it.
    fn set_token() {
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("token.txt");
        let token = std::fs::read_to_string(&path).unwrap_or_else(|err| {
            panic!(
                "cannot read Hugging Face token from {}: {}",
                path.display(),
                err
            )
        });
        // Editors usually leave a trailing newline; it must not end up in the token.
        set_hf_token(token.trim());
    }

    /// Build the configs for `preset` with the shared prompt and generate an image.
    fn run(preset: Preset) {
        let (config, mut model_config) = PresetBuilder::default()
            .preset(preset)
            .prompt(PROMPT)
            .build()
            .unwrap();
        gen_img(&config, &mut model_config).unwrap();
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_1_4() {
        run(Preset::StableDiffusion1_4);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_1_5() {
        run(Preset::StableDiffusion1_5);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_2_1() {
        run(Preset::StableDiffusion2_1);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_medium_fp16() {
        set_token();
        run(Preset::StableDiffusion3Medium);
    }

    #[ignore]
    #[test]
    fn test_sdxl_base_1_0() {
        run(Preset::SDXLBase1_0);
    }

    #[ignore]
    #[test]
    fn test_flux_1_dev() {
        set_token();
        run(Preset::Flux1Dev(Flux1Weight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_flux_1_schnell() {
        set_token();
        run(Preset::Flux1Schnell(Flux1Weight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_sd_turbo() {
        run(Preset::SDTurbo);
    }

    #[ignore]
    #[test]
    fn test_sdxl_turbo_1_0_fp16() {
        run(Preset::SDXLTurbo1_0);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_5_medium_fp16() {
        set_token();
        run(Preset::StableDiffusion3_5Medium);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_5_large_fp16() {
        set_token();
        run(Preset::StableDiffusion3_5Large);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_5_large_turbo_fp16() {
        set_token();
        run(Preset::StableDiffusion3_5LargeTurbo);
    }

    #[ignore]
    #[test]
    fn test_juggernaut_xl_11() {
        set_token();
        run(Preset::JuggernautXL11);
    }

    #[ignore]
    #[test]
    fn test_flux_1_mini() {
        set_token();
        run(Preset::Flux1Mini(Flux1MiniWeight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_chroma() {
        set_token();
        run(Preset::Chroma(ChromaWeight::Q4_0));
    }

    #[ignore]
    #[test]
    fn test_nitro_sd_realism() {
        run(Preset::NitroSDRealism(NitroSDRealismWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_nitro_sd_vibrant() {
        run(Preset::NitroSDVibrant(NitroSDVibrantWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_diff_instruct_star() {
        run(Preset::DiffInstructStar(DiffInstructStarWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_chroma_radiance() {
        run(Preset::ChromaRadiance(ChromaRadianceWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_ssd_1b() {
        run(Preset::SSD1B(SSD1BWeight::F8_E4M3));
    }

    #[ignore]
    #[test]
    fn test_flux_2_dev() {
        set_token();
        run(Preset::Flux2Dev(Flux2Weight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_z_image_turbo() {
        set_token();
        run(Preset::ZImageTurbo(ZImageTurboWeight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_qwen_image() {
        run(Preset::QwenImage(QwenImageWeight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_ovis_image() {
        set_token();
        run(Preset::OvisImage(OvisImageWeight::Q4_0));
    }

    #[ignore]
    #[test]
    fn test_dreamshaper_xl_2_1_turbo() {
        run(Preset::DreamShaperXL2_1Turbo);
    }

    #[ignore]
    #[test]
    fn test_twinflow_z_image_turbo_exp() {
        set_token();
        run(Preset::TwinFlowZImageTurboExp(
            TwinFlowZImageTurboExpWeight::Q3_K,
        ));
    }
}