1use derive_builder::Builder;
2use hf_hub::api::sync::ApiError;
3
4use crate::{
5 api::{self, Config, ConfigBuilder, ConfigBuilderError, ModelConfig, ModelConfigBuilder},
6 preset_builder::{
7 chroma, flux_1_dev, flux_1_mini, flux_1_schnell, juggernaut_xl_11, sd_turbo, sdxl_base_1_0,
8 sdxl_turbo_1_0_fp16, stable_diffusion_1_4, stable_diffusion_1_5, stable_diffusion_2_1,
9 stable_diffusion_3_5_large_fp16, stable_diffusion_3_5_large_turbo_fp16,
10 stable_diffusion_3_5_medium_fp16, stable_diffusion_3_medium_fp16,
11 },
12};
13
/// A built-in model preset that [`PresetBuilder`] expands into ready-to-use
/// [`Configs`].
///
/// Each variant is resolved by the private `Preset::try_configs_builder`,
/// which calls the matching function in `crate::preset_builder`. Variants that
/// carry an [`api::WeightType`] pass it straight through to that function
/// (presumably to select the weight quantization — confirm in `preset_builder`).
///
/// The enum is `#[non_exhaustive]`: matches outside this crate need a wildcard
/// arm, so new presets can be added without a breaking change.
#[non_exhaustive]
#[derive(Debug, Clone, Copy)]
pub enum Preset {
    /// Resolved by `preset_builder::stable_diffusion_1_4`.
    StableDiffusion1_4,
    /// Resolved by `preset_builder::stable_diffusion_1_5`.
    StableDiffusion1_5,
    /// Resolved by `preset_builder::stable_diffusion_2_1`.
    StableDiffusion2_1,
    /// Resolved by `preset_builder::stable_diffusion_3_medium_fp16`.
    ///
    /// NOTE(review): this module's test sets a Hugging Face token before using
    /// this preset, so the weights are presumably gated — confirm.
    StableDiffusion3MediumFp16,
    /// Resolved by `preset_builder::stable_diffusion_3_5_medium_fp16`.
    ///
    /// NOTE(review): the test sets a Hugging Face token first; presumably gated.
    StableDiffusion3_5MediumFp16,
    /// Resolved by `preset_builder::stable_diffusion_3_5_large_fp16`.
    ///
    /// NOTE(review): the test sets a Hugging Face token first; presumably gated.
    StableDiffusion3_5LargeFp16,
    /// Resolved by `preset_builder::stable_diffusion_3_5_large_turbo_fp16`.
    ///
    /// NOTE(review): the test sets a Hugging Face token first; presumably gated.
    StableDiffusion3_5LargeTurboFp16,
    /// Resolved by `preset_builder::sdxl_base_1_0`.
    SDXLBase1_0,
    /// Resolved by `preset_builder::sd_turbo`.
    SDTurbo,
    /// Resolved by `preset_builder::sdxl_turbo_1_0_fp16`.
    SDXLTurbo1_0Fp16,
    /// Resolved by `preset_builder::flux_1_dev`; the weight type is forwarded to it.
    ///
    /// NOTE(review): the test sets a Hugging Face token first; presumably gated.
    Flux1Dev(api::WeightType),
    /// Resolved by `preset_builder::flux_1_schnell`; the weight type is forwarded to it.
    ///
    /// NOTE(review): the test sets a Hugging Face token first; presumably gated.
    Flux1Schnell(api::WeightType),
    /// Resolved by `preset_builder::flux_1_mini`; the weight type is forwarded to it.
    ///
    /// NOTE(review): the test sets a Hugging Face token first; presumably gated.
    Flux1Mini(api::WeightType),
    /// Resolved by `preset_builder::juggernaut_xl_11`.
    ///
    /// NOTE(review): the test sets a Hugging Face token first; presumably gated.
    JuggernautXL11,
    /// Resolved by `preset_builder::chroma`; the weight type is forwarded to it.
    ///
    /// NOTE(review): the test sets a Hugging Face token first; presumably gated.
    Chroma(api::WeightType),
}
57
58impl Preset {
59 fn try_configs_builder(self) -> Result<(ConfigBuilder, ModelConfigBuilder), ApiError> {
60 match self {
61 Preset::StableDiffusion1_4 => stable_diffusion_1_4(),
62 Preset::StableDiffusion1_5 => stable_diffusion_1_5(),
63 Preset::StableDiffusion2_1 => stable_diffusion_2_1(),
64 Preset::StableDiffusion3MediumFp16 => stable_diffusion_3_medium_fp16(),
65 Preset::SDXLBase1_0 => sdxl_base_1_0(),
66 Preset::Flux1Dev(sd_type_t) => flux_1_dev(sd_type_t),
67 Preset::Flux1Schnell(sd_type_t) => flux_1_schnell(sd_type_t),
68 Preset::SDTurbo => sd_turbo(),
69 Preset::SDXLTurbo1_0Fp16 => sdxl_turbo_1_0_fp16(),
70 Preset::StableDiffusion3_5LargeFp16 => stable_diffusion_3_5_large_fp16(),
71 Preset::StableDiffusion3_5MediumFp16 => stable_diffusion_3_5_medium_fp16(),
72 Preset::StableDiffusion3_5LargeTurboFp16 => stable_diffusion_3_5_large_turbo_fp16(),
73 Preset::JuggernautXL11 => juggernaut_xl_11(),
74 Preset::Flux1Mini(sd_type_t) => flux_1_mini(sd_type_t),
75 Preset::Chroma(sd_type_t) => chroma(sd_type_t),
76 }
77 }
78}
79
/// The pair of builders a [`Preset`] expands into: the generation
/// `ConfigBuilder` first, the `ModelConfigBuilder` second.
pub type ConfigsBuilder = (ConfigBuilder, ModelConfigBuilder);

/// The finished `(Config, ModelConfig)` pair returned by `PresetBuilder::build`.
pub type Configs = (Config, ModelConfig);

/// A step that transforms a preset's [`ConfigsBuilder`] before the prompt is set.
///
/// Modifiers run in the order they were added; returning `Err` aborts the
/// conversion. Being a plain `fn` pointer, a modifier cannot capture state
/// (only non-capturing closures coerce to it).
pub type Modifier = fn(ConfigsBuilder) -> Result<ConfigsBuilder, ApiError>;
88
/// Inputs gathered by [`PresetBuilder`]: a prompt, the [`Preset`] to start
/// from, and an ordered list of [`Modifier`]s.
///
/// The derived builder is renamed to `PresetBuilder`. Its generated build
/// function is made private as `internal_build`, so that the public
/// `PresetBuilder::build` can return finished [`Configs`] rather than this
/// struct. Build failures from the derived builder use `ConfigBuilderError`.
#[derive(Debug, Clone, Builder)]
#[builder(
    name = "PresetBuilder",
    setter(into),
    build_fn(name = "internal_build", private, error = "ConfigBuilderError")
)]
pub struct PresetConfig {
    /// Text prompt; written into the `ConfigBuilder` after every modifier has run.
    prompt: String,
    /// Preset that supplies the starting builders.
    preset: Preset,
    /// Modifiers applied in insertion order; empty when none were added.
    /// The generated setter is private — append via `PresetBuilder::with_modifier`.
    #[builder(private, default = "Vec::new()")]
    modifiers: Vec<Modifier>,
}
102
103impl PresetBuilder {
104 pub fn with_modifier(&mut self, f: Modifier) -> &mut Self {
106 if self.modifiers.is_none() {
107 self.modifiers = Some(Vec::new());
108 }
109 self.modifiers.as_mut().unwrap().push(f);
110 self
111 }
112
113 pub fn build(&mut self) -> Result<Configs, ConfigBuilderError> {
114 let preset = self.internal_build()?;
115 let configs: ConfigsBuilder = preset
116 .try_into()
117 .map_err(|err: ApiError| ConfigBuilderError::ValidationError(err.to_string()))?;
118 let config = configs.0.build()?;
119 let config_model = configs.1.build()?;
120
121 Ok((config, config_model))
122 }
123}
124
125impl TryFrom<PresetConfig> for ConfigsBuilder {
126 type Error = ApiError;
127
128 fn try_from(value: PresetConfig) -> Result<Self, Self::Error> {
129 let mut configs_builder = value.preset.try_configs_builder()?;
130 for m in value.modifiers {
131 configs_builder = m(configs_builder)?;
132 }
133 configs_builder.0.prompt(value.prompt);
134 Ok(configs_builder)
135 }
136}
137
#[cfg(test)]
mod tests {
    use crate::{
        api::{self, gen_img},
        util::set_hf_token,
    };

    use super::{Preset, PresetBuilder};
    static PROMPT: &str = "a lovely cat holding a sign says 'diffusion-rs'";

    /// File holding the Hugging Face token needed by the presets whose tests
    /// call `run_with_token`.
    ///
    /// The file is read at runtime rather than with `include_str!`, so the test
    /// target still compiles when the file is absent (e.g. a fresh clone or CI).
    /// NOTE(review): assumes this source file lives directly under `src/`, which
    /// matches the former `include_str!("../token.txt")` — confirm.
    const TOKEN_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/token.txt");

    /// Builds configs for `preset` with the shared prompt and generates an image.
    fn run(preset: Preset) {
        let (mut config, mut model_config) = PresetBuilder::default()
            .preset(preset)
            .prompt(PROMPT)
            .build()
            .unwrap();
        gen_img(&mut config, &mut model_config).unwrap();
    }

    /// Installs the Hugging Face token from [`TOKEN_PATH`], then calls [`run`].
    ///
    /// # Panics
    ///
    /// Panics if the token file cannot be read.
    fn run_with_token(preset: Preset) {
        let token = std::fs::read_to_string(TOKEN_PATH)
            .unwrap_or_else(|err| panic!("failed to read HF token from {TOKEN_PATH}: {err}"));
        // Strip a trailing newline left by editors or `echo > token.txt`.
        set_hf_token(token.trim());
        run(preset);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_1_4() {
        run(Preset::StableDiffusion1_4);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_1_5() {
        run(Preset::StableDiffusion1_5);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_2_1() {
        run(Preset::StableDiffusion2_1);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_medium_fp16() {
        run_with_token(Preset::StableDiffusion3MediumFp16);
    }

    #[ignore]
    #[test]
    fn test_sdxl_base_1_0() {
        run(Preset::SDXLBase1_0);
    }

    #[ignore]
    #[test]
    fn test_flux_1_dev() {
        run_with_token(Preset::Flux1Dev(api::WeightType::SD_TYPE_Q2_K));
    }

    #[ignore]
    #[test]
    fn test_flux_1_schnell() {
        run_with_token(Preset::Flux1Schnell(api::WeightType::SD_TYPE_Q2_K));
    }

    #[ignore]
    #[test]
    fn test_sd_turbo() {
        run(Preset::SDTurbo);
    }

    #[ignore]
    #[test]
    fn test_sdxl_turbo_1_0_fp16() {
        run(Preset::SDXLTurbo1_0Fp16);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_5_medium_fp16() {
        run_with_token(Preset::StableDiffusion3_5MediumFp16);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_5_large_fp16() {
        run_with_token(Preset::StableDiffusion3_5LargeFp16);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_5_large_turbo_fp16() {
        run_with_token(Preset::StableDiffusion3_5LargeTurboFp16);
    }

    #[ignore]
    #[test]
    fn test_juggernaut_xl_11() {
        run_with_token(Preset::JuggernautXL11);
    }

    #[ignore]
    #[test]
    fn test_flux_1_mini() {
        run_with_token(Preset::Flux1Mini(api::WeightType::SD_TYPE_Q8_0));
    }

    #[ignore]
    #[test]
    fn test_chroma() {
        run_with_token(Preset::Chroma(api::WeightType::SD_TYPE_Q4_0));
    }
}