diffusion_rs/
preset.rs

1use derive_builder::Builder;
2use hf_hub::api::sync::ApiError;
3
4use crate::{
5    api::{self, Config, ConfigBuilder, ConfigBuilderError, ModelConfig, ModelConfigBuilder},
6    preset_builder::{
7        chroma, flux_1_dev, flux_1_mini, flux_1_schnell, juggernaut_xl_11, sd_turbo, sdxl_base_1_0,
8        sdxl_turbo_1_0_fp16, stable_diffusion_1_4, stable_diffusion_1_5, stable_diffusion_2_1,
9        stable_diffusion_3_5_large_fp16, stable_diffusion_3_5_large_turbo_fp16,
10        stable_diffusion_3_5_medium_fp16, stable_diffusion_3_medium_fp16,
11    },
12};
13
14#[non_exhaustive]
15#[derive(Debug, Clone, Copy)]
16/// Models ready to use
17pub enum Preset {
18    StableDiffusion1_4,
19    StableDiffusion1_5,
20    /// <https://huggingface.co/stabilityai/stable-diffusion-2-1> model.
21    ///  Vae-tiling enabled. 768x768.
22    StableDiffusion2_1,
23    /// Requires access rights to <https://huggingface.co/stabilityai/stable-diffusion-3-medium> providing a token via [crate::util::set_hf_token]
24    /// Vae-tiling enabled. 1024x1024. Enabled [api::SampleMethod::EULER]. 30 steps.
25    StableDiffusion3MediumFp16,
26    /// Requires access rights to <https://huggingface.co/stabilityai/stable-diffusion-3.5-medium> providing a token via [crate::util::set_hf_token]
27    /// Vae-tiling enabled. 1024x1024. Enabled [api::SampleMethod::EULER]. cfg_scale 4.5. 40 steps.
28    StableDiffusion3_5MediumFp16,
29    /// Requires access rights to <https://huggingface.co/stabilityai/stable-diffusion-3.5-large> providing a token via [crate::util::set_hf_token]
30    /// Vae-tiling enabled. 1024x1024. Enabled [api::SampleMethod::EULER]. cfg_scale 4.5. 28 steps.
31    StableDiffusion3_5LargeFp16,
32    /// Requires access rights to <https://huggingface.co/stabilityai/stable-diffusion-3.5-large-turbo> providing a token via [crate::util::set_hf_token]
33    /// Vae-tiling enabled. 1024x1024. Enabled [api::SampleMethod::EULER]. cfg_scale 0. 4 steps.
34    StableDiffusion3_5LargeTurboFp16,
35    SDXLBase1_0,
36    /// cfg_scale 1. guidance 0. 4 steps
37    SDTurbo,
38    /// cfg_scale 1. guidance 0. 4 steps
39    SDXLTurbo1_0Fp16,
40    /// Requires access rights to <https://huggingface.co/black-forest-labs/FLUX.1-dev> providing a token via [crate::util::set_hf_token]
41    /// Vae-tiling enabled. 1024x1024. Enabled [api::SampleMethod::EULER]. 28 steps.
42    Flux1Dev(api::WeightType),
43    /// Requires access rights to <https://huggingface.co/black-forest-labs/FLUX.1-schnell> providing a token via [crate::util::set_hf_token]
44    /// Vae-tiling enabled. 1024x1024. Enabled [api::SampleMethod::EULER]. 4 steps.
45    Flux1Schnell(api::WeightType),
46    /// A 3.2B param rectified flow transformer distilled from FLUX.1-dev <https://huggingface.co/TencentARC/flux-mini> <https://huggingface.co/HyperX-Sentience/Flux-Mini-GGUF>
47    /// Vae-tiling enabled. 512x512. Enabled [api::SampleMethod::EULER]. cfg_scale 1. 20 steps.
48    Flux1Mini(api::WeightType),
49    /// Requires access rights to <https://huggingface.co/RunDiffusion/Juggernaut-XI-v11> providing a token via [crate::util::set_hf_token]
50    /// Vae-tiling enabled. 1024x1024. Enabled [api::SampleMethod::DPM2]. guidance 6. 20 steps
51    JuggernautXL11,
52    /// Chroma is an 8.9B parameter model based on FLUX.1-schnell
53    /// Requires access rights to <https://huggingface.co/black-forest-labs/FLUX.1-dev> providing a token via [crate::util::set_hf_token]
54    /// Vae-tiling enabled. 512x512. Enabled [api::SampleMethod::EULER]. cfg_scale 4. 20 steps
55    Chroma(api::WeightType),
56}
57
58impl Preset {
59    fn try_configs_builder(self) -> Result<(ConfigBuilder, ModelConfigBuilder), ApiError> {
60        match self {
61            Preset::StableDiffusion1_4 => stable_diffusion_1_4(),
62            Preset::StableDiffusion1_5 => stable_diffusion_1_5(),
63            Preset::StableDiffusion2_1 => stable_diffusion_2_1(),
64            Preset::StableDiffusion3MediumFp16 => stable_diffusion_3_medium_fp16(),
65            Preset::SDXLBase1_0 => sdxl_base_1_0(),
66            Preset::Flux1Dev(sd_type_t) => flux_1_dev(sd_type_t),
67            Preset::Flux1Schnell(sd_type_t) => flux_1_schnell(sd_type_t),
68            Preset::SDTurbo => sd_turbo(),
69            Preset::SDXLTurbo1_0Fp16 => sdxl_turbo_1_0_fp16(),
70            Preset::StableDiffusion3_5LargeFp16 => stable_diffusion_3_5_large_fp16(),
71            Preset::StableDiffusion3_5MediumFp16 => stable_diffusion_3_5_medium_fp16(),
72            Preset::StableDiffusion3_5LargeTurboFp16 => stable_diffusion_3_5_large_turbo_fp16(),
73            Preset::JuggernautXL11 => juggernaut_xl_11(),
74            Preset::Flux1Mini(sd_type_t) => flux_1_mini(sd_type_t),
75            Preset::Chroma(sd_type_t) => chroma(sd_type_t),
76        }
77    }
78}
79
80/// Configs tuple used by [crate::modifier]
81pub type ConfigsBuilder = (ConfigBuilder, ModelConfigBuilder);
82
83/// Returned by [PresetBuilder::build]
84pub type Configs = (Config, ModelConfig);
85
86/// Helper functions that modifies the [ConfigBuilder] See [crate::modifier]
87pub type Modifier = fn(ConfigsBuilder) -> Result<ConfigsBuilder, ApiError>;
88
#[derive(Debug, Clone, Builder)]
#[builder(
    name = "PresetBuilder",
    setter(into),
    // The generated build fn is renamed and made private; the public `build` on
    // `PresetBuilder` wraps it to produce the final configs.
    build_fn(name = "internal_build", private, error = "ConfigBuilderError")
)]
/// Inputs for building configs from a [Preset]: a prompt, the preset itself, and an
/// ordered list of [Modifier]s.
///
/// Construct it through the derived [PresetBuilder]; its `build` method turns these
/// inputs into a ready-to-use `(Config, ModelConfig)` pair.
pub struct PresetConfig {
    /// Prompt applied to the [ConfigBuilder] after all modifiers have run.
    prompt: String,
    /// Preset that supplies the starting config builders.
    preset: Preset,
    /// Modifiers applied in insertion order. The generated setter is private; add
    /// entries with [PresetBuilder::with_modifier]. Empty when none were added.
    #[builder(private, default = "Vec::new()")]
    modifiers: Vec<Modifier>,
}
102
103impl PresetBuilder {
104    /// Add modifier that will apply in sequence
105    pub fn with_modifier(&mut self, f: Modifier) -> &mut Self {
106        if self.modifiers.is_none() {
107            self.modifiers = Some(Vec::new());
108        }
109        self.modifiers.as_mut().unwrap().push(f);
110        self
111    }
112
113    pub fn build(&mut self) -> Result<Configs, ConfigBuilderError> {
114        let preset = self.internal_build()?;
115        let configs: ConfigsBuilder = preset
116            .try_into()
117            .map_err(|err: ApiError| ConfigBuilderError::ValidationError(err.to_string()))?;
118        let config = configs.0.build()?;
119        let config_model = configs.1.build()?;
120
121        Ok((config, config_model))
122    }
123}
124
125impl TryFrom<PresetConfig> for ConfigsBuilder {
126    type Error = ApiError;
127
128    fn try_from(value: PresetConfig) -> Result<Self, Self::Error> {
129        let mut configs_builder = value.preset.try_configs_builder()?;
130        for m in value.modifiers {
131            configs_builder = m(configs_builder)?;
132        }
133        configs_builder.0.prompt(value.prompt);
134        Ok(configs_builder)
135    }
136}
137
#[cfg(test)]
mod tests {
    use crate::{
        api::{self, gen_img},
        util::set_hf_token,
    };

    use super::{Preset, PresetBuilder};
    static PROMPT: &str = "a lovely cat holding a sign says 'diffusion-rs'";

    /// Registers the Hugging Face token stored in `token.txt` in the crate root
    /// (`CARGO_MANIFEST_DIR`).
    ///
    /// The file is read at run time rather than embedded with `include_str!`. This lets
    /// the test crate compile when the file is absent; only the gated tests that call
    /// this helper need it. Surrounding whitespace, such as a trailing newline left by
    /// an editor, is trimmed so it does not end up in the token.
    fn set_token_from_file() {
        let path = concat!(env!("CARGO_MANIFEST_DIR"), "/token.txt");
        let token = std::fs::read_to_string(path)
            .unwrap_or_else(|err| panic!("cannot read Hugging Face token from {}: {}", path, err));
        set_hf_token(token.trim());
    }

    /// Builds the configs for `preset` with the shared prompt and generates an image.
    fn run(preset: Preset) {
        let (mut config, mut model_config) = PresetBuilder::default()
            .preset(preset)
            .prompt(PROMPT)
            .build()
            .unwrap();
        gen_img(&mut config, &mut model_config).unwrap();
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_1_4() {
        run(Preset::StableDiffusion1_4);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_1_5() {
        run(Preset::StableDiffusion1_5);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_2_1() {
        run(Preset::StableDiffusion2_1);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_medium_fp16() {
        set_token_from_file();
        run(Preset::StableDiffusion3MediumFp16);
    }

    #[ignore]
    #[test]
    fn test_sdxl_base_1_0() {
        run(Preset::SDXLBase1_0);
    }

    #[ignore]
    #[test]
    fn test_flux_1_dev() {
        set_token_from_file();
        run(Preset::Flux1Dev(api::WeightType::SD_TYPE_Q2_K));
    }

    #[ignore]
    #[test]
    fn test_flux_1_schnell() {
        set_token_from_file();
        run(Preset::Flux1Schnell(api::WeightType::SD_TYPE_Q2_K));
    }

    #[ignore]
    #[test]
    fn test_sd_turbo() {
        run(Preset::SDTurbo);
    }

    #[ignore]
    #[test]
    fn test_sdxl_turbo_1_0_fp16() {
        run(Preset::SDXLTurbo1_0Fp16);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_5_medium_fp16() {
        set_token_from_file();
        run(Preset::StableDiffusion3_5MediumFp16);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_5_large_fp16() {
        set_token_from_file();
        run(Preset::StableDiffusion3_5LargeFp16);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_5_large_turbo_fp16() {
        set_token_from_file();
        run(Preset::StableDiffusion3_5LargeTurboFp16);
    }

    #[ignore]
    #[test]
    fn test_juggernaut_xl_11() {
        set_token_from_file();
        run(Preset::JuggernautXL11);
    }

    #[ignore]
    #[test]
    fn test_flux_1_mini() {
        set_token_from_file();
        run(Preset::Flux1Mini(api::WeightType::SD_TYPE_Q8_0));
    }

    #[ignore]
    #[test]
    fn test_chroma() {
        set_token_from_file();
        run(Preset::Chroma(api::WeightType::SD_TYPE_Q4_0));
    }
}