1use hf_hub::api::sync::ApiError;
2
3use crate::{
4 api::{LoraSpec, PreviewType, SampleMethod},
5 preset::ConfigsBuilder,
6 util::download_file_hf_hub,
7};
8
9pub fn real_esrgan_x4plus_anime_6_b(
11 mut builder: ConfigsBuilder,
12) -> Result<ConfigsBuilder, ApiError> {
13 let upscaler_path = download_file_hf_hub(
14 "ximso/RealESRGAN_x4plus_anime_6B",
15 "RealESRGAN_x4plus_anime_6B.pth",
16 )?;
17 builder.1.upscale_model(upscaler_path);
18 Ok(builder)
19}
20
21pub fn sdxl_vae_fp16_fix(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
23 let vae_path = download_file_hf_hub("madebyollin/sdxl-vae-fp16-fix", "sdxl.vae.safetensors")?;
24 builder.1.vae(vae_path);
25 Ok(builder)
26}
27
28pub fn taesd(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
30 let taesd_path =
31 download_file_hf_hub("madebyollin/taesd", "diffusion_pytorch_model.safetensors")?;
32 builder.1.taesd(taesd_path);
33 Ok(builder)
34}
35
36pub fn taesd_xl(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
38 let taesd_path =
39 download_file_hf_hub("madebyollin/taesdxl", "diffusion_pytorch_model.safetensors")?;
40 builder.1.taesd(taesd_path);
41 Ok(builder)
42}
43
44pub fn hybrid_taesd(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
46 let taesd_path = download_file_hf_hub(
47 "cqyan/hybrid-sd-tinyvae",
48 "diffusion_pytorch_model.safetensors",
49 )?;
50 builder.1.taesd(taesd_path);
51 Ok(builder)
52}
53
54pub fn hybrid_taesd_xl(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
56 let taesd_path = download_file_hf_hub(
57 "cqyan/hybrid-sd-tinyvae-xl",
58 "diffusion_pytorch_model.safetensors",
59 )?;
60 builder.1.taesd(taesd_path);
61 Ok(builder)
62}
63
64pub fn lcm_lora_sd_1_5(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
67 let lora_path = download_file_hf_hub(
68 "latent-consistency/lcm-lora-sdv1-5",
69 "pytorch_lora_weights.safetensors",
70 )?;
71 builder.1.lora_models(
72 lora_path.parent().unwrap(),
73 vec![LoraSpec {
74 file_name: "pytorch_lora_weights".to_string(),
75 is_high_noise: false,
76 multiplier: 1.0,
77 }],
78 );
79 builder.0.cfg_scale(1.).steps(8);
80 Ok(builder)
81}
82
83pub fn lcm_lora_sdxl_base_1_0(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
86 let lora_path = download_file_hf_hub(
87 "latent-consistency/lcm-lora-sdxl",
88 "pytorch_lora_weights.safetensors",
89 )?;
90
91 builder.1.lora_models(
92 lora_path.parent().unwrap(),
93 vec![LoraSpec {
94 file_name: "pytorch_lora_weights".to_string(),
95 is_high_noise: false,
96 multiplier: 1.0,
97 }],
98 );
99 builder
100 .0
101 .cfg_scale(2.)
102 .steps(8)
103 .sampling_method(SampleMethod::LCM_SAMPLE_METHOD);
104 Ok(builder)
105}
106
107pub fn t5xxl_fp8_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
109 let t5xxl_path = download_file_hf_hub(
110 "comfyanonymous/flux_text_encoders",
111 "t5xxl_fp8_e4m3fn.safetensors",
112 )?;
113
114 builder.1.t5xxl(t5xxl_path);
115 Ok(builder)
116}
117
118pub fn t5xxl_fp16_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
121 let t5xxl_path = download_file_hf_hub(
122 "comfyanonymous/flux_text_encoders",
123 "t5xxl_fp16.safetensors",
124 )?;
125
126 builder.1.t5xxl(t5xxl_path);
127 Ok(builder)
128}
129
130pub fn t5xxl_q2_k_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
132 let t5xxl_path = download_file_hf_hub("Green-Sky/flux.1-schnell-GGUF", "t5xxl_q2_k.gguf")?;
133
134 builder.1.t5xxl(t5xxl_path);
135 Ok(builder)
136}
137
138pub fn t5xxl_q3_k_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
140 let t5xxl_path = download_file_hf_hub("Green-Sky/flux.1-schnell-GGUF", "t5xxl_q3_k.gguf")?;
141
142 builder.1.t5xxl(t5xxl_path);
143 Ok(builder)
144}
145
146pub fn t5xxl_q4_k_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
149 let t5xxl_path = download_file_hf_hub("Green-Sky/flux.1-schnell-GGUF", "t5xxl_q4_k.gguf")?;
150
151 builder.1.t5xxl(t5xxl_path);
152 Ok(builder)
153}
154
155pub fn t5xxl_q8_0_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
157 let t5xxl_path = download_file_hf_hub("Green-Sky/flux.1-schnell-GGUF", "t5xxl_q8_0.gguf")?;
158
159 builder.1.t5xxl(t5xxl_path);
160 Ok(builder)
161}
162
163pub fn offload_params_to_cpu(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
164 builder.1.offload_params_to_cpu(true);
165 Ok(builder)
166}
167
168pub fn lcm_lora_ssd_1b(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
171 let lora_path = download_file_hf_hub(
172 "kylielee505/mylcmlorassd",
173 "pytorch_lora_weights.safetensors",
174 )?;
175 builder.1.lora_models(
176 lora_path.parent().unwrap(),
177 vec![LoraSpec {
178 file_name: "pytorch_lora_weights".to_string(),
179 is_high_noise: false,
180 multiplier: 1.0,
181 }],
182 );
183 builder.0.cfg_scale(1.).steps(8);
184 Ok(builder)
185}
186
187pub fn vae_tiling(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
189 builder.1.vae_tiling(true);
190 Ok(builder)
191}
192
193pub fn preview_proj(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
195 builder.0.preview_mode(PreviewType::PREVIEW_PROJ);
196 Ok(builder)
197}
198
199pub fn preview_tae(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
201 builder.0.preview_mode(PreviewType::PREVIEW_TAE);
202 Ok(builder)
203}
204
205pub fn preview_vae(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
207 builder.0.preview_mode(PreviewType::PREVIEW_VAE);
208 Ok(builder)
209}
210
211pub fn enable_easycache(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
213 builder.1.easy_cache(true);
214 Ok(builder)
215}
216
217pub fn enable_flash_attention(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
219 builder.1.flash_attention(true);
220 Ok(builder)
221}
222
/// End-to-end checks: each test builds a preset with one modifier applied and generates an
/// image. All of them are `#[ignore]`d, so they only run when requested explicitly.
#[cfg(test)]
mod tests {
    use super::{
        enable_easycache, enable_flash_attention, hybrid_taesd, hybrid_taesd_xl, lcm_lora_sd_1_5,
        lcm_lora_sdxl_base_1_0, lcm_lora_ssd_1b, offload_params_to_cpu, preview_proj, preview_tae,
        preview_vae, taesd, taesd_xl, vae_tiling,
    };
    use crate::{
        api::gen_img,
        preset::{Flux1MiniWeight, Flux1Weight, Modifier, Preset, PresetBuilder, SSD1BWeight},
        util::set_hf_token,
    };

    /// Prompt shared by every test in this module.
    const PROMPT: &str = "a lovely dynosaur made by crochet";

    /// Passes the contents of `../token.txt` to `set_hf_token`. The path is resolved relative
    /// to this source file at compile time by `include_str!`.
    fn load_hf_token() {
        set_hf_token(include_str!("../token.txt"));
    }

    /// Builds `preset` with [`PROMPT`] and `modifier`, then runs `gen_img` on the result.
    ///
    /// Panics (via `unwrap`) if building the configuration or generating the image fails.
    fn run(preset: Preset, modifier: Modifier) {
        let (mut cfg, mut model_cfg) = PresetBuilder::default()
            .preset(preset)
            .prompt(PROMPT)
            .with_modifier(modifier)
            .build()
            .unwrap();
        gen_img(&mut cfg, &mut model_cfg).unwrap();
    }

    #[ignore]
    #[test]
    fn test_taesd() {
        run(Preset::StableDiffusion1_5, taesd);
    }

    #[ignore]
    #[test]
    fn test_taesd_xl() {
        run(Preset::SDXLTurbo1_0Fp16, taesd_xl);
    }

    #[ignore]
    #[test]
    fn test_hybrid_taesd() {
        run(Preset::StableDiffusion1_5, hybrid_taesd);
    }

    #[ignore]
    #[test]
    fn test_hybrid_taesd_xl() {
        run(Preset::SDXLTurbo1_0Fp16, hybrid_taesd_xl);
    }

    #[ignore]
    #[test]
    fn test_lcm_lora_sd_1_5() {
        run(Preset::StableDiffusion1_5, lcm_lora_sd_1_5);
    }

    #[ignore]
    #[test]
    fn test_lcm_lora_sdxl_base_1_0() {
        run(Preset::SDXLBase1_0, lcm_lora_sdxl_base_1_0);
    }

    #[ignore]
    #[test]
    fn test_offload_params_to_cpu() {
        load_hf_token();
        run(Preset::Flux1Schnell(Flux1Weight::Q2_K), offload_params_to_cpu);
    }

    #[ignore]
    #[test]
    fn test_lcm_lora_ssd_1b() {
        run(Preset::SSD1B(SSD1BWeight::F8_E4M3), lcm_lora_ssd_1b);
    }

    #[ignore]
    #[test]
    fn test_vae_tiling() {
        run(Preset::SSD1B(SSD1BWeight::F8_E4M3), vae_tiling);
    }

    #[ignore]
    #[test]
    fn test_preview_proj() {
        run(Preset::SDXLTurbo1_0Fp16, preview_proj);
    }

    #[ignore]
    #[test]
    fn test_preview_tae() {
        run(Preset::SDXLTurbo1_0Fp16, preview_tae);
    }

    #[ignore]
    #[test]
    fn test_preview_vae() {
        run(Preset::SDXLTurbo1_0Fp16, preview_vae);
    }

    #[ignore]
    #[test]
    fn test_easy_cache() {
        load_hf_token();
        run(Preset::Flux1Mini(Flux1MiniWeight::Q2_K), enable_easycache);
    }

    #[ignore]
    #[test]
    fn test_flash_attention() {
        load_hf_token();
        run(Preset::Flux1Mini(Flux1MiniWeight::Q2_K), enable_flash_attention);
    }
}