Skip to main content

diffusion_rs/
modifier.rs

1use hf_hub::api::sync::ApiError;
2
3use crate::{
4    api::{LoraSpec, PreviewType, SampleMethod},
5    preset::ConfigsBuilder,
6    util::download_file_hf_hub,
7};
8
9/// Add the <https://huggingface.co/ximso/RealESRGAN_x4plus_anime_6B> upscaler
10pub fn real_esrgan_x4plus_anime_6_b(
11    mut builder: ConfigsBuilder,
12) -> Result<ConfigsBuilder, ApiError> {
13    let upscaler_path = download_file_hf_hub(
14        "ximso/RealESRGAN_x4plus_anime_6B",
15        "RealESRGAN_x4plus_anime_6B.pth",
16    )?;
17    builder.1.upscale_model(upscaler_path);
18    Ok(builder)
19}
20
21/// Apply <https://huggingface.co/madebyollin/sdxl-vae-fp16-fix> to avoid black images with xl models
22pub fn sdxl_vae_fp16_fix(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
23    let vae_path = download_file_hf_hub("madebyollin/sdxl-vae-fp16-fix", "sdxl.vae.safetensors")?;
24    builder.1.vae(vae_path);
25    Ok(builder)
26}
27
28/// Apply <https://huggingface.co/madebyollin/taesd> taesd autoencoder for faster decoding (SD v1/v2)
29pub fn taesd(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
30    let taesd_path =
31        download_file_hf_hub("madebyollin/taesd", "diffusion_pytorch_model.safetensors")?;
32    builder.1.taesd(taesd_path);
33    Ok(builder)
34}
35
36/// Apply <https://huggingface.co/madebyollin/taesdxl> taesd autoencoder for faster decoding (SDXL)
37pub fn taesd_xl(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
38    let taesd_path =
39        download_file_hf_hub("madebyollin/taesdxl", "diffusion_pytorch_model.safetensors")?;
40    builder.1.taesd(taesd_path);
41    Ok(builder)
42}
43
44/// Apply <https://huggingface.co/cqyan/hybrid-sd-tinyvae> taesd autoencoder for faster decoding (SD v1/v2)
45pub fn hybrid_taesd(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
46    let taesd_path = download_file_hf_hub(
47        "cqyan/hybrid-sd-tinyvae",
48        "diffusion_pytorch_model.safetensors",
49    )?;
50    builder.1.taesd(taesd_path);
51    Ok(builder)
52}
53
54/// Apply <https://huggingface.co/cqyan/hybrid-sd-tinyvae-xl> taesd autoencoder for faster decoding (SDXL)
55pub fn hybrid_taesd_xl(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
56    let taesd_path = download_file_hf_hub(
57        "cqyan/hybrid-sd-tinyvae-xl",
58        "diffusion_pytorch_model.safetensors",
59    )?;
60    builder.1.taesd(taesd_path);
61    Ok(builder)
62}
63
64/// Apply <https://huggingface.co/latent-consistency/lcm-lora-sdv1-5> to reduce inference steps for SD v1 between 2-8 (default 8)
65/// cfg_scale 1. 8 steps.
66pub fn lcm_lora_sd_1_5(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
67    let lora_path = download_file_hf_hub(
68        "latent-consistency/lcm-lora-sdv1-5",
69        "pytorch_lora_weights.safetensors",
70    )?;
71    builder.1.lora_models(
72        lora_path.parent().unwrap(),
73        vec![LoraSpec {
74            file_name: "pytorch_lora_weights".to_string(),
75            is_high_noise: false,
76            multiplier: 1.0,
77        }],
78    );
79    builder.0.cfg_scale(1.).steps(8);
80    Ok(builder)
81}
82
83/// Apply <https://huggingface.co/latent-consistency/lcm-lora-sdxl> to reduce inference steps for SD v1 between 2-8 (default 8)
84/// Enabled [SampleMethod::LCM_SAMPLE_METHOD]. cfg_scale 2. 8 steps.
85pub fn lcm_lora_sdxl_base_1_0(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
86    let lora_path = download_file_hf_hub(
87        "latent-consistency/lcm-lora-sdxl",
88        "pytorch_lora_weights.safetensors",
89    )?;
90
91    builder.1.lora_models(
92        lora_path.parent().unwrap(),
93        vec![LoraSpec {
94            file_name: "pytorch_lora_weights".to_string(),
95            is_high_noise: false,
96            multiplier: 1.0,
97        }],
98    );
99    builder
100        .0
101        .cfg_scale(2.)
102        .steps(8)
103        .sampling_method(SampleMethod::LCM_SAMPLE_METHOD);
104    Ok(builder)
105}
106
107/// Apply <https://huggingface.co/nerijs/pixel-art-xl>
108pub fn lora_pixel_art_sdxl_base_1_0(
109    mut builder: ConfigsBuilder,
110) -> Result<ConfigsBuilder, ApiError> {
111    let lora_path = download_file_hf_hub("nerijs/pixel-art-xl", "pixel-art-xl.safetensors")?;
112
113    builder.1.lora_models(
114        lora_path.parent().unwrap(),
115        vec![LoraSpec {
116            file_name: "pixel-art-xl".to_string(),
117            is_high_noise: false,
118            multiplier: 1.2,
119        }],
120    );
121    Ok(builder)
122}
123
124/// Apply <https://huggingface.co/nerijs/pastelcomic-flux>
125pub fn lora_pastelcomic_2_flux(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
126    let lora_path = download_file_hf_hub("nerijs/pastelcomic-flux", "pastelcomic_v2.safetensors")?;
127
128    builder.1.lora_models(
129        lora_path.parent().unwrap(),
130        vec![LoraSpec {
131            file_name: "pastelcomic_v2".to_string(),
132            is_high_noise: false,
133            multiplier: 1.2,
134        }],
135    );
136    Ok(builder)
137}
138
139/// Apply <https://huggingface.co/strangerzonehf/Ghibli-Flux-Cartoon-LoRA>
140pub fn lora_ghibli_flux(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
141    let lora_path = download_file_hf_hub(
142        "strangerzonehf/Ghibli-Flux-Cartoon-LoRA",
143        "Ghibili-Cartoon-Art.safetensors",
144    )?;
145
146    builder.1.lora_models(
147        lora_path.parent().unwrap(),
148        vec![LoraSpec {
149            file_name: "Ghibili-Cartoon-Art".to_string(),
150            is_high_noise: false,
151            multiplier: 1.0,
152        }],
153    );
154    Ok(builder)
155}
156
157/// Apply <https://huggingface.co/strangerzonehf/Flux-Midjourney-Mix2-LoRA>
158pub fn lora_midjourney_mix_2_flux(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
159    let lora_path = download_file_hf_hub(
160        "strangerzonehf/Flux-Midjourney-Mix2-LoRA",
161        "mjV6.safetensors",
162    )?;
163
164    builder.1.lora_models(
165        lora_path.parent().unwrap(),
166        vec![LoraSpec {
167            file_name: "mjV6".to_string(),
168            is_high_noise: false,
169            multiplier: 1.0,
170        }],
171    );
172    Ok(builder)
173}
174
175/// Apply <https://huggingface.co/prithivMLmods/Retro-Pixel-Flux-LoRA>
176pub fn lora_retro_pixel_flux(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
177    let lora_path = download_file_hf_hub(
178        "prithivMLmods/Retro-Pixel-Flux-LoRA",
179        "Retro-Pixel.safetensors",
180    )?;
181
182    builder.1.lora_models(
183        lora_path.parent().unwrap(),
184        vec![LoraSpec {
185            file_name: "Retro-Pixel".to_string(),
186            is_high_noise: false,
187            multiplier: 1.0,
188        }],
189    );
190    Ok(builder)
191}
192
193/// Apply <https://huggingface.co/prithivMLmods/Canopus-Pixar-3D-Flux-LoRA>
194pub fn lora_canopus_pixar_3d_flux(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
195    let lora_path = download_file_hf_hub(
196        "prithivMLmods/Canopus-Pixar-3D-Flux-LoRA",
197        "Canopus-Pixar-3D-FluxDev-LoRA.safetensors",
198    )?;
199
200    builder.1.lora_models(
201        lora_path.parent().unwrap(),
202        vec![LoraSpec {
203            file_name: "Canopus-Pixar-3D-FluxDev-LoRA".to_string(),
204            is_high_noise: false,
205            multiplier: 1.0,
206        }],
207    );
208    Ok(builder)
209}
210
211/// Apply <https://huggingface.co/comfyanonymous/flux_text_encoders/blob/main/t5xxl_fp8_e4m3fn.safetensors> fp8_e4m3fn t5xxl text encoder to reduce memory usage
212pub fn t5xxl_fp8_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
213    let t5xxl_path = download_file_hf_hub(
214        "comfyanonymous/flux_text_encoders",
215        "t5xxl_fp8_e4m3fn.safetensors",
216    )?;
217
218    builder.1.t5xxl(t5xxl_path);
219    Ok(builder)
220}
221
222/// Apply <https://huggingface.co/comfyanonymous/flux_text_encoders/blob/main/t5xxl_fp16.safetensors>
223/// Default for flux_1_dev/schnell
224pub fn t5xxl_fp16_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
225    let t5xxl_path = download_file_hf_hub(
226        "comfyanonymous/flux_text_encoders",
227        "t5xxl_fp16.safetensors",
228    )?;
229
230    builder.1.t5xxl(t5xxl_path);
231    Ok(builder)
232}
233
234/// Apply <https://huggingface.co/Green-Sky/flux.1-schnell-GGUF/blob/main/t5xxl_q2_k.gguf>
235pub fn t5xxl_q2_k_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
236    let t5xxl_path = download_file_hf_hub("Green-Sky/flux.1-schnell-GGUF", "t5xxl_q2_k.gguf")?;
237
238    builder.1.t5xxl(t5xxl_path);
239    Ok(builder)
240}
241
242/// Apply <https://huggingface.co/Green-Sky/flux.1-schnell-GGUF/blob/main/t5xxl_q3_k.gguf>
243pub fn t5xxl_q3_k_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
244    let t5xxl_path = download_file_hf_hub("Green-Sky/flux.1-schnell-GGUF", "t5xxl_q3_k.gguf")?;
245
246    builder.1.t5xxl(t5xxl_path);
247    Ok(builder)
248}
249
250/// Apply <https://huggingface.co/Green-Sky/flux.1-schnell-GGUF/blob/main/t5xxl_q4_k.gguf>
251/// Default for flux_1_mini
252pub fn t5xxl_q4_k_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
253    let t5xxl_path = download_file_hf_hub("Green-Sky/flux.1-schnell-GGUF", "t5xxl_q4_k.gguf")?;
254
255    builder.1.t5xxl(t5xxl_path);
256    Ok(builder)
257}
258
259/// Apply <https://huggingface.co/Green-Sky/flux.1-schnell-GGUF/blob/main/t5xxl_q8_0.gguf>
260pub fn t5xxl_q8_0_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
261    let t5xxl_path = download_file_hf_hub("Green-Sky/flux.1-schnell-GGUF", "t5xxl_q8_0.gguf")?;
262
263    builder.1.t5xxl(t5xxl_path);
264    Ok(builder)
265}
266
267/// Offload model parameters to CPU (for low VRAM GPUs)
268pub fn offload_params_to_cpu(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
269    builder.1.offload_params_to_cpu(true);
270    Ok(builder)
271}
272
273/// Apply <https://huggingface.co/kylielee505/mylcmlorassd> to reduce inference steps for SD v1 between 2-8 (default 8)
274/// cfg_scale 1. 8 steps.
275pub fn lcm_lora_ssd_1b(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
276    let lora_path = download_file_hf_hub(
277        "kylielee505/mylcmlorassd",
278        "pytorch_lora_weights.safetensors",
279    )?;
280    builder.1.lora_models(
281        lora_path.parent().unwrap(),
282        vec![LoraSpec {
283            file_name: "pytorch_lora_weights".to_string(),
284            is_high_noise: false,
285            multiplier: 1.0,
286        }],
287    );
288    builder.0.cfg_scale(1.).steps(8);
289    Ok(builder)
290}
291
292/// Enable vae tiling
293pub fn vae_tiling(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
294    builder.1.vae_tiling(true);
295    Ok(builder)
296}
297
298/// Enable preview with [crate::api::PreviewType::PREVIEW_PROJ]
299pub fn preview_proj(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
300    builder.0.preview_mode(PreviewType::PREVIEW_PROJ);
301    Ok(builder)
302}
303
304/// Enable preview with [crate::api::PreviewType::PREVIEW_TAE]
305pub fn preview_tae(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
306    builder.0.preview_mode(PreviewType::PREVIEW_TAE);
307    Ok(builder)
308}
309
310/// Enable preview with [crate::api::PreviewType::PREVIEW_VAE]
311pub fn preview_vae(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
312    builder.0.preview_mode(PreviewType::PREVIEW_VAE);
313    Ok(builder)
314}
315
316/// Enable flash attention
317pub fn enable_flash_attention(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
318    builder.1.flash_attention(true);
319    Ok(builder)
320}
321
322/// Apply <https://huggingface.co/segmind/Segmind-VegaRT> to [crate::preset::Preset::SegmindVega]
323pub fn lcm_lora_segmind_vega_rt(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
324    let lora_path =
325        download_file_hf_hub("segmind/Segmind-VegaRT", "pytorch_lora_weights.safetensors")?;
326    builder.1.lora_models(
327        lora_path.parent().unwrap(),
328        vec![LoraSpec {
329            file_name: "pytorch_lora_weights".to_string(),
330            is_high_noise: false,
331            multiplier: 1.0,
332        }],
333    );
334    builder.0.guidance(0.).steps(4);
335    Ok(builder)
336}
337
338/// Apply <https://huggingface.co/Einhorn/Anima-Preview_8_Step_Turbo_Lora>
339pub fn lora_anima_8_steps_turbo(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
340    let lora_path = download_file_hf_hub(
341        "Einhorn/Anima-Preview_8_Step_Turbo_Lora",
342        "Anima-Preview_Turbo_8step.safetensors",
343    )?;
344
345    builder.1.lora_models(
346        lora_path.parent().unwrap(),
347        vec![LoraSpec {
348            file_name: "Anima-Preview_Turbo_8step".to_string(),
349            is_high_noise: false,
350            multiplier: 1.0,
351        }],
352    );
353    builder.0.cfg_scale(1.).steps(8);
354    Ok(builder)
355}
356
// Integration tests: each one downloads real model weights from the Hugging
// Face hub and runs a full image generation, so they are all #[ignore]d and
// must be invoked explicitly (e.g. `cargo test -- --ignored`). Some Flux
// presets are gated repos and require a HF token read from ../token.txt.
#[cfg(test)]
mod tests {
    use hf_hub::api::sync::ApiError;

    use crate::{
        api::gen_img,
        modifier::{
            enable_flash_attention, lcm_lora_segmind_vega_rt, lcm_lora_ssd_1b,
            lora_anima_8_steps_turbo, lora_canopus_pixar_3d_flux, lora_ghibli_flux,
            lora_midjourney_mix_2_flux, lora_pastelcomic_2_flux, lora_pixel_art_sdxl_base_1_0,
            lora_retro_pixel_flux, offload_params_to_cpu, preview_proj, preview_tae, preview_vae,
            vae_tiling,
        },
        preset::{AnimaWeight, ConfigsBuilder, Flux1Weight, Preset, PresetBuilder},
        util::set_hf_token,
    };

    use super::{
        hybrid_taesd, hybrid_taesd_xl, lcm_lora_sd_1_5, lcm_lora_sdxl_base_1_0, taesd, taesd_xl,
    };

    // Default prompt shared by most tests; style-LoRA tests override it with
    // trigger-word prompts specific to each LoRA.
    static PROMPT: &str = "a lovely dinosaur made by crochet";

    // Build the preset with the given modifier applied, then generate one
    // image end-to-end. Panics on any builder or generation failure, which
    // is what makes these smoke tests fail.
    fn run<F>(preset: Preset, prompt: &str, m: F)
    where
        F: FnOnce(ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> + 'static,
    {
        let (mut config, mut model_config) = PresetBuilder::default()
            .preset(preset)
            .prompt(prompt)
            .with_modifier(m)
            .build()
            .unwrap();
        gen_img(&mut config, &mut model_config).unwrap();
    }

    #[ignore]
    #[test]
    fn test_taesd() {
        run(Preset::StableDiffusion1_5, PROMPT, taesd);
    }

    #[ignore]
    #[test]
    fn test_taesd_xl() {
        run(Preset::SDXLTurbo1_0, PROMPT, taesd_xl);
    }

    #[ignore]
    #[test]
    fn test_hybrid_taesd() {
        run(Preset::StableDiffusion1_5, PROMPT, hybrid_taesd);
    }

    #[ignore]
    #[test]
    fn test_hybrid_taesd_xl() {
        run(Preset::SDXLTurbo1_0, PROMPT, hybrid_taesd_xl);
    }

    #[ignore]
    #[test]
    fn test_lcm_lora_sd_1_5() {
        run(Preset::StableDiffusion1_5, PROMPT, lcm_lora_sd_1_5);
    }

    #[ignore]
    #[test]
    fn test_lcm_lora_sdxl_base_1_0() {
        run(Preset::SDXLBase1_0, PROMPT, lcm_lora_sdxl_base_1_0);
    }

    #[ignore]
    #[test]
    fn test_offload_params_to_cpu() {
        // Flux weights are gated; authenticate before downloading.
        set_hf_token(include_str!("../token.txt"));
        run(
            Preset::Flux1Schnell(Flux1Weight::Q2_K),
            PROMPT,
            offload_params_to_cpu,
        );
    }

    #[ignore]
    #[test]
    fn test_lcm_lora_ssd_1b() {
        run(
            Preset::SSD1B(crate::preset::SSD1BWeight::F8_E4M3),
            PROMPT,
            lcm_lora_ssd_1b,
        );
    }

    #[ignore]
    #[test]
    fn test_vae_tiling() {
        run(
            Preset::SSD1B(crate::preset::SSD1BWeight::F8_E4M3),
            PROMPT,
            vae_tiling,
        );
    }

    #[ignore]
    #[test]
    fn test_preview_proj() {
        run(Preset::SDXLTurbo1_0, PROMPT, preview_proj);
    }

    #[ignore]
    #[test]
    fn test_preview_tae() {
        run(Preset::SDXLTurbo1_0, PROMPT, preview_tae);
    }

    #[ignore]
    #[test]
    fn test_preview_vae() {
        run(Preset::SDXLTurbo1_0, PROMPT, preview_vae);
    }

    #[ignore]
    #[test]
    fn test_flash_attention() {
        // Flux weights are gated; authenticate before downloading.
        set_hf_token(include_str!("../token.txt"));
        run(
            Preset::Flux1Mini(crate::preset::Flux1MiniWeight::Q2_K),
            PROMPT,
            enable_flash_attention,
        );
    }

    #[ignore]
    #[test]
    fn test_segmind_vega_rt_lcm_lora() {
        run(Preset::SegmindVega, PROMPT, lcm_lora_segmind_vega_rt);
    }

    // The remaining tests use LoRA-specific trigger prompts taken from each
    // model card rather than the shared PROMPT.
    #[ignore]
    #[test]
    fn test_lora_pixel_art_xl() {
        run(
            Preset::SDXLBase1_0,
            "pixel, a cute corgi",
            lora_pixel_art_sdxl_base_1_0,
        );
    }

    #[ignore]
    #[test]
    fn test_lora_pastelcomic_2_flux() {
        set_hf_token(include_str!("../token.txt"));
        run(
            Preset::Flux1Schnell(Flux1Weight::Q2_K),
            PROMPT,
            lora_pastelcomic_2_flux,
        );
    }

    #[ignore]
    #[test]
    fn test_lora_ghibli_flux() {
        set_hf_token(include_str!("../token.txt"));
        run(
            Preset::Flux1Schnell(Flux1Weight::Q2_K),
            "Ghibli Art – A wise old fisherman sits on a wooden dock, gazing out at the vast, blue ocean. He wears a worn-out straw hat and a navy-blue coat, and he holds a fishing rod in his hands. A black cat with bright green eyes sits beside him, watching the waves. In the distance, a lighthouse stands tall against the horizon, with seagulls soaring in the sky. The water glistens under the golden sunset.",
            lora_ghibli_flux,
        );
    }

    #[ignore]
    #[test]
    fn test_lora_midjourney_mix_2_flux() {
        set_hf_token(include_str!("../token.txt"));
        run(
            Preset::Flux1Schnell(Flux1Weight::Q2_K),
            "MJ v6, delicious dipped chocolate pastry japo gallery, white background, in the style of dark brown, close-up intensity, duckcore, rounded, high resolution --ar 2:3 --v 5",
            lora_midjourney_mix_2_flux,
        );
    }

    #[ignore]
    #[test]
    fn test_lora_retro_pixel_flux() {
        set_hf_token(include_str!("../token.txt"));
        run(
            Preset::Flux1Schnell(Flux1Weight::Q2_K),
            "Retro Pixel, pixel art of a Hamburger in the style of an old video game, hero, pixelated 8bit, final boss ",
            lora_retro_pixel_flux,
        );
    }

    #[ignore]
    #[test]
    fn test_lora_canopus_pixar_3d_flux() {
        set_hf_token(include_str!("../token.txt"));
        run(
            Preset::Flux1Schnell(Flux1Weight::Q2_K),
            "A young man with light brown wavy hair and light brown eyes sitting in an armchair and looking directly at the camera, pixar style, disney pixar, office background, ultra detailed, 1 man",
            lora_canopus_pixar_3d_flux,
        );
    }

    #[ignore]
    #[test]
    fn test_lora_anima_8_steps_turbo() {
        run(
            Preset::Anima(AnimaWeight::Q6_K),
            PROMPT,
            lora_anima_8_steps_turbo,
        );
    }
}
569}