1use std::collections::HashMap;
2
3use hf_hub::api::sync::ApiError;
4use strum::IntoEnumIterator;
5
6use crate::{
7 api::{BackendDevice, LoraSpec, Module, PreviewType, SampleMethod},
8 preset::ConfigsBuilder,
9 util::download_file_hf_hub,
10};
11
12pub fn real_esrgan_x4plus_anime_6_b(
14 mut builder: ConfigsBuilder,
15) -> Result<ConfigsBuilder, ApiError> {
16 let upscaler_path = download_file_hf_hub(
17 "ximso/RealESRGAN_x4plus_anime_6B",
18 "RealESRGAN_x4plus_anime_6B.pth",
19 )?;
20 builder.1.upscale_model(upscaler_path);
21 Ok(builder)
22}
23
24pub fn sdxl_vae_fp16_fix(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
26 let vae_path = download_file_hf_hub("madebyollin/sdxl-vae-fp16-fix", "sdxl.vae.safetensors")?;
27 builder.1.vae(vae_path);
28 Ok(builder)
29}
30
31pub fn taesd(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
33 let taesd_path =
34 download_file_hf_hub("madebyollin/taesd", "diffusion_pytorch_model.safetensors")?;
35 builder.1.taesd(taesd_path);
36 Ok(builder)
37}
38
39pub fn taesd_xl(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
41 let taesd_path =
42 download_file_hf_hub("madebyollin/taesdxl", "diffusion_pytorch_model.safetensors")?;
43 builder.1.taesd(taesd_path);
44 Ok(builder)
45}
46
47pub fn hybrid_taesd(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
49 let taesd_path = download_file_hf_hub(
50 "cqyan/hybrid-sd-tinyvae",
51 "diffusion_pytorch_model.safetensors",
52 )?;
53 builder.1.taesd(taesd_path);
54 Ok(builder)
55}
56
57pub fn hybrid_taesd_xl(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
59 let taesd_path = download_file_hf_hub(
60 "cqyan/hybrid-sd-tinyvae-xl",
61 "diffusion_pytorch_model.safetensors",
62 )?;
63 builder.1.taesd(taesd_path);
64 Ok(builder)
65}
66
67pub fn lcm_lora_sd_1_5(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
70 let lora_path = download_file_hf_hub(
71 "latent-consistency/lcm-lora-sdv1-5",
72 "pytorch_lora_weights.safetensors",
73 )?;
74 builder.1.lora_models(
75 lora_path.parent().unwrap(),
76 vec![LoraSpec {
77 file_name: "pytorch_lora_weights".to_string(),
78 is_high_noise: false,
79 multiplier: 1.0,
80 }],
81 );
82 builder.0.cfg_scale(1.).steps(8);
83 Ok(builder)
84}
85
86pub fn lcm_lora_sdxl_base_1_0(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
89 let lora_path = download_file_hf_hub(
90 "latent-consistency/lcm-lora-sdxl",
91 "pytorch_lora_weights.safetensors",
92 )?;
93
94 builder.1.lora_models(
95 lora_path.parent().unwrap(),
96 vec![LoraSpec {
97 file_name: "pytorch_lora_weights".to_string(),
98 is_high_noise: false,
99 multiplier: 1.0,
100 }],
101 );
102 builder
103 .0
104 .cfg_scale(2.)
105 .steps(8)
106 .sampling_method(SampleMethod::LCM_SAMPLE_METHOD);
107 Ok(builder)
108}
109
110pub fn lora_pixel_art_sdxl_base_1_0(
112 mut builder: ConfigsBuilder,
113) -> Result<ConfigsBuilder, ApiError> {
114 let lora_path = download_file_hf_hub("nerijs/pixel-art-xl", "pixel-art-xl.safetensors")?;
115
116 builder.1.lora_models(
117 lora_path.parent().unwrap(),
118 vec![LoraSpec {
119 file_name: "pixel-art-xl".to_string(),
120 is_high_noise: false,
121 multiplier: 1.2,
122 }],
123 );
124 Ok(builder)
125}
126
127pub fn lora_pastelcomic_2_flux(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
129 let lora_path = download_file_hf_hub("nerijs/pastelcomic-flux", "pastelcomic_v2.safetensors")?;
130
131 builder.1.lora_models(
132 lora_path.parent().unwrap(),
133 vec![LoraSpec {
134 file_name: "pastelcomic_v2".to_string(),
135 is_high_noise: false,
136 multiplier: 1.2,
137 }],
138 );
139 Ok(builder)
140}
141
142pub fn lora_ghibli_flux(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
144 let lora_path = download_file_hf_hub(
145 "strangerzonehf/Ghibli-Flux-Cartoon-LoRA",
146 "Ghibili-Cartoon-Art.safetensors",
147 )?;
148
149 builder.1.lora_models(
150 lora_path.parent().unwrap(),
151 vec![LoraSpec {
152 file_name: "Ghibili-Cartoon-Art".to_string(),
153 is_high_noise: false,
154 multiplier: 1.0,
155 }],
156 );
157 Ok(builder)
158}
159
160pub fn lora_midjourney_mix_2_flux(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
162 let lora_path = download_file_hf_hub(
163 "strangerzonehf/Flux-Midjourney-Mix2-LoRA",
164 "mjV6.safetensors",
165 )?;
166
167 builder.1.lora_models(
168 lora_path.parent().unwrap(),
169 vec![LoraSpec {
170 file_name: "mjV6".to_string(),
171 is_high_noise: false,
172 multiplier: 1.0,
173 }],
174 );
175 Ok(builder)
176}
177
178pub fn lora_retro_pixel_flux(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
180 let lora_path = download_file_hf_hub(
181 "prithivMLmods/Retro-Pixel-Flux-LoRA",
182 "Retro-Pixel.safetensors",
183 )?;
184
185 builder.1.lora_models(
186 lora_path.parent().unwrap(),
187 vec![LoraSpec {
188 file_name: "Retro-Pixel".to_string(),
189 is_high_noise: false,
190 multiplier: 1.0,
191 }],
192 );
193 Ok(builder)
194}
195
196pub fn lora_canopus_pixar_3d_flux(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
198 let lora_path = download_file_hf_hub(
199 "prithivMLmods/Canopus-Pixar-3D-Flux-LoRA",
200 "Canopus-Pixar-3D-FluxDev-LoRA.safetensors",
201 )?;
202
203 builder.1.lora_models(
204 lora_path.parent().unwrap(),
205 vec![LoraSpec {
206 file_name: "Canopus-Pixar-3D-FluxDev-LoRA".to_string(),
207 is_high_noise: false,
208 multiplier: 1.0,
209 }],
210 );
211 Ok(builder)
212}
213
214pub fn t5xxl_fp8_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
216 let t5xxl_path = download_file_hf_hub(
217 "comfyanonymous/flux_text_encoders",
218 "t5xxl_fp8_e4m3fn.safetensors",
219 )?;
220
221 builder.1.t5xxl(t5xxl_path);
222 Ok(builder)
223}
224
225pub fn t5xxl_fp16_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
228 let t5xxl_path = download_file_hf_hub(
229 "comfyanonymous/flux_text_encoders",
230 "t5xxl_fp16.safetensors",
231 )?;
232
233 builder.1.t5xxl(t5xxl_path);
234 Ok(builder)
235}
236
237pub fn t5xxl_q2_k_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
239 let t5xxl_path = download_file_hf_hub("Green-Sky/flux.1-schnell-GGUF", "t5xxl_q2_k.gguf")?;
240
241 builder.1.t5xxl(t5xxl_path);
242 Ok(builder)
243}
244
245pub fn t5xxl_q3_k_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
247 let t5xxl_path = download_file_hf_hub("Green-Sky/flux.1-schnell-GGUF", "t5xxl_q3_k.gguf")?;
248
249 builder.1.t5xxl(t5xxl_path);
250 Ok(builder)
251}
252
253pub fn t5xxl_q4_k_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
256 let t5xxl_path = download_file_hf_hub("Green-Sky/flux.1-schnell-GGUF", "t5xxl_q4_k.gguf")?;
257
258 builder.1.t5xxl(t5xxl_path);
259 Ok(builder)
260}
261
262pub fn t5xxl_q8_0_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
264 let t5xxl_path = download_file_hf_hub("Green-Sky/flux.1-schnell-GGUF", "t5xxl_q8_0.gguf")?;
265
266 builder.1.t5xxl(t5xxl_path);
267 Ok(builder)
268}
269
270pub fn offload_params_to_cpu(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
272 let params: HashMap<_, _> = Module::iter()
273 .map(|module| (module, BackendDevice::CPU))
274 .collect();
275 builder.1.params_backend(params);
277 Ok(builder)
278}
279
280pub fn lazily_load_params_from_disk(
282 mut builder: ConfigsBuilder,
283) -> Result<ConfigsBuilder, ApiError> {
284 let params: HashMap<_, _> = Module::iter()
285 .map(|module| (module, BackendDevice::DISK))
286 .collect();
287 builder.1.params_backend(params);
289 Ok(builder)
290}
291
292pub fn lcm_lora_ssd_1b(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
295 let lora_path = download_file_hf_hub(
296 "kylielee505/mylcmlorassd",
297 "pytorch_lora_weights.safetensors",
298 )?;
299 builder.1.lora_models(
300 lora_path.parent().unwrap(),
301 vec![LoraSpec {
302 file_name: "pytorch_lora_weights".to_string(),
303 is_high_noise: false,
304 multiplier: 1.0,
305 }],
306 );
307 builder.0.cfg_scale(1.).steps(8);
308 Ok(builder)
309}
310
311pub fn vae_tiling(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
313 builder.1.vae_tiling(true);
314 Ok(builder)
315}
316
317pub fn preview_proj(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
319 builder.0.preview_mode(PreviewType::PREVIEW_PROJ);
320 Ok(builder)
321}
322
323pub fn preview_tae(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
325 builder.0.preview_mode(PreviewType::PREVIEW_TAE);
326 Ok(builder)
327}
328
329pub fn preview_vae(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
331 builder.0.preview_mode(PreviewType::PREVIEW_VAE);
332 Ok(builder)
333}
334
335pub fn enable_flash_attention(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
337 builder.1.flash_attention(true);
338 Ok(builder)
339}
340
341pub fn lcm_lora_segmind_vega_rt(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
343 let lora_path =
344 download_file_hf_hub("segmind/Segmind-VegaRT", "pytorch_lora_weights.safetensors")?;
345 builder.1.lora_models(
346 lora_path.parent().unwrap(),
347 vec![LoraSpec {
348 file_name: "pytorch_lora_weights".to_string(),
349 is_high_noise: false,
350 multiplier: 1.0,
351 }],
352 );
353 builder.0.guidance(0.).steps(4);
354 Ok(builder)
355}
356
357pub fn lora_anima_8_steps_turbo(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
359 let lora_path = download_file_hf_hub(
360 "Einhorn/Anima-Preview_8_Step_Turbo_Lora",
361 "Anima-Preview_Turbo_8step.safetensors",
362 )?;
363
364 builder.1.lora_models(
365 lora_path.parent().unwrap(),
366 vec![LoraSpec {
367 file_name: "Anima-Preview_Turbo_8step".to_string(),
368 is_high_noise: false,
369 multiplier: 1.0,
370 }],
371 );
372 builder.0.cfg_scale(1.).steps(8);
373 Ok(builder)
374}
375
376pub fn flux_2_small_decoder(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
378 let vae_path = download_file_hf_hub(
379 "black-forest-labs/FLUX.2-small-decoder",
380 "full_encoder_small_decoder.safetensors",
381 )?;
382 builder.1.vae(vae_path);
383 Ok(builder)
384}
385
386#[cfg(test)]
387mod tests {
388 use hf_hub::api::sync::ApiError;
389
390 use crate::{
391 api::gen_img,
392 modifier::{
393 enable_flash_attention, flux_2_small_decoder, lcm_lora_segmind_vega_rt,
394 lcm_lora_ssd_1b, lora_anima_8_steps_turbo, lora_canopus_pixar_3d_flux,
395 lora_ghibli_flux, lora_midjourney_mix_2_flux, lora_pastelcomic_2_flux,
396 lora_pixel_art_sdxl_base_1_0, lora_retro_pixel_flux, offload_params_to_cpu,
397 preview_proj, preview_tae, preview_vae, vae_tiling,
398 },
399 preset::{
400 AnimaWeight, ConfigsBuilder, Flux1Weight, Flux2Klein4BWeight, Preset, PresetBuilder,
401 },
402 util::set_hf_token,
403 };
404
405 use super::{
406 hybrid_taesd, hybrid_taesd_xl, lcm_lora_sd_1_5, lcm_lora_sdxl_base_1_0, taesd, taesd_xl,
407 };
408
409 static PROMPT: &str = "a lovely dinosaur made by crochet";
410
411 fn run<F>(preset: Preset, prompt: &str, m: F)
412 where
413 F: FnOnce(ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> + 'static,
414 {
415 let (mut config, mut model_config) = PresetBuilder::default()
416 .preset(preset)
417 .prompt(prompt)
418 .with_modifier(m)
419 .build()
420 .unwrap();
421 gen_img(&mut config, &mut model_config).unwrap();
422 }
423
424 #[ignore]
425 #[test]
426 fn test_taesd() {
427 run(Preset::StableDiffusion1_5, PROMPT, taesd);
428 }
429
430 #[ignore]
431 #[test]
432 fn test_taesd_xl() {
433 run(Preset::SDXLTurbo1_0, PROMPT, taesd_xl);
434 }
435
436 #[ignore]
437 #[test]
438 fn test_hybrid_taesd() {
439 run(Preset::StableDiffusion1_5, PROMPT, hybrid_taesd);
440 }
441
442 #[ignore]
443 #[test]
444 fn test_hybrid_taesd_xl() {
445 run(Preset::SDXLTurbo1_0, PROMPT, hybrid_taesd_xl);
446 }
447
448 #[ignore]
449 #[test]
450 fn test_lcm_lora_sd_1_5() {
451 run(Preset::StableDiffusion1_5, PROMPT, lcm_lora_sd_1_5);
452 }
453
454 #[ignore]
455 #[test]
456 fn test_lcm_lora_sdxl_base_1_0() {
457 run(Preset::SDXLBase1_0, PROMPT, lcm_lora_sdxl_base_1_0);
458 }
459
460 #[ignore]
461 #[test]
462 fn test_offload_params_to_cpu() {
463 set_hf_token(include_str!("../token.txt"));
464 run(
465 Preset::Flux1Schnell(Flux1Weight::Q2_K),
466 PROMPT,
467 offload_params_to_cpu,
468 );
469 }
470
471 #[ignore]
472 #[test]
473 fn test_lcm_lora_ssd_1b() {
474 run(
475 Preset::SSD1B(crate::preset::SSD1BWeight::F8_E4M3),
476 PROMPT,
477 lcm_lora_ssd_1b,
478 );
479 }
480
481 #[ignore]
482 #[test]
483 fn test_vae_tiling() {
484 run(
485 Preset::SSD1B(crate::preset::SSD1BWeight::F8_E4M3),
486 PROMPT,
487 vae_tiling,
488 );
489 }
490
491 #[ignore]
492 #[test]
493 fn test_preview_proj() {
494 run(Preset::SDXLTurbo1_0, PROMPT, preview_proj);
495 }
496
497 #[ignore]
498 #[test]
499 fn test_preview_tae() {
500 run(Preset::SDXLTurbo1_0, PROMPT, preview_tae);
501 }
502
503 #[ignore]
504 #[test]
505 fn test_preview_vae() {
506 run(Preset::SDXLTurbo1_0, PROMPT, preview_vae);
507 }
508
509 #[ignore]
510 #[test]
511 fn test_flash_attention() {
512 set_hf_token(include_str!("../token.txt"));
513 run(
514 Preset::Flux1Mini(crate::preset::Flux1MiniWeight::Q2_K),
515 PROMPT,
516 enable_flash_attention,
517 );
518 }
519
520 #[ignore]
521 #[test]
522 fn test_segmind_vega_rt_lcm_lora() {
523 run(Preset::SegmindVega, PROMPT, lcm_lora_segmind_vega_rt);
524 }
525
526 #[ignore]
527 #[test]
528 fn test_lora_pixel_art_xl() {
529 run(
530 Preset::SDXLBase1_0,
531 "pixel, a cute corgi",
532 lora_pixel_art_sdxl_base_1_0,
533 );
534 }
535
536 #[ignore]
537 #[test]
538 fn test_lora_pastelcomic_2_flux() {
539 set_hf_token(include_str!("../token.txt"));
540 run(
541 Preset::Flux1Schnell(Flux1Weight::Q2_K),
542 PROMPT,
543 lora_pastelcomic_2_flux,
544 );
545 }
546
547 #[ignore]
548 #[test]
549 fn test_lora_ghibli_flux() {
550 set_hf_token(include_str!("../token.txt"));
551 run(
552 Preset::Flux1Schnell(Flux1Weight::Q2_K),
553 "Ghibli Art – A wise old fisherman sits on a wooden dock, gazing out at the vast, blue ocean. He wears a worn-out straw hat and a navy-blue coat, and he holds a fishing rod in his hands. A black cat with bright green eyes sits beside him, watching the waves. In the distance, a lighthouse stands tall against the horizon, with seagulls soaring in the sky. The water glistens under the golden sunset.",
554 lora_ghibli_flux,
555 );
556 }
557
558 #[ignore]
559 #[test]
560 fn test_lora_midjourney_mix_2_flux() {
561 set_hf_token(include_str!("../token.txt"));
562 run(
563 Preset::Flux1Schnell(Flux1Weight::Q2_K),
564 "MJ v6, delicious dipped chocolate pastry japo gallery, white background, in the style of dark brown, close-up intensity, duckcore, rounded, high resolution --ar 2:3 --v 5",
565 lora_midjourney_mix_2_flux,
566 );
567 }
568
569 #[ignore]
570 #[test]
571 fn test_lora_retro_pixel_flux() {
572 set_hf_token(include_str!("../token.txt"));
573 run(
574 Preset::Flux1Schnell(Flux1Weight::Q2_K),
575 "Retro Pixel, pixel art of a Hamburger in the style of an old video game, hero, pixelated 8bit, final boss ",
576 lora_retro_pixel_flux,
577 );
578 }
579
580 #[ignore]
581 #[test]
582 fn test_lora_canopus_pixar_3d_flux() {
583 set_hf_token(include_str!("../token.txt"));
584 run(
585 Preset::Flux1Schnell(Flux1Weight::Q2_K),
586 "A young man with light brown wavy hair and light brown eyes sitting in an armchair and looking directly at the camera, pixar style, disney pixar, office background, ultra detailed, 1 man",
587 lora_canopus_pixar_3d_flux,
588 );
589 }
590
591 #[ignore]
592 #[test]
593 fn test_lora_anima_8_steps_turbo() {
594 run(
595 Preset::Anima(AnimaWeight::Q6_K),
596 PROMPT,
597 lora_anima_8_steps_turbo,
598 );
599 }
600
601 #[ignore]
602 #[test]
603 fn test_flux_2_small_decoder() {
604 run(
605 Preset::Flux2Klein4B(Flux2Klein4BWeight::Q8_0),
606 PROMPT,
607 flux_2_small_decoder,
608 );
609 }
610}