diffusion_rs/
modifier.rs

1use hf_hub::api::sync::ApiError;
2
3use crate::{api::SampleMethod, preset::ConfigsBuilder, util::download_file_hf_hub};
4
5/// Add the <https://huggingface.co/ximso/RealESRGAN_x4plus_anime_6B> upscaler
6pub fn real_esrgan_x4plus_anime_6_b(
7    mut builder: ConfigsBuilder,
8) -> Result<ConfigsBuilder, ApiError> {
9    let upscaler_path = download_file_hf_hub(
10        "ximso/RealESRGAN_x4plus_anime_6B",
11        "RealESRGAN_x4plus_anime_6B.pth",
12    )?;
13    builder.1.upscale_model(upscaler_path);
14    Ok(builder)
15}
16
17/// Apply <https://huggingface.co/madebyollin/sdxl-vae-fp16-fix> to avoid black images with xl models
18pub fn sdxl_vae_fp16_fix(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
19    let vae_path = download_file_hf_hub("madebyollin/sdxl-vae-fp16-fix", "sdxl.vae.safetensors")?;
20    builder.1.vae(vae_path);
21    Ok(builder)
22}
23
24/// Apply <https://huggingface.co/madebyollin/taesd> taesd autoencoder for faster decoding (SD v1/v2)
25pub fn taesd(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
26    let taesd_path =
27        download_file_hf_hub("madebyollin/taesd", "diffusion_pytorch_model.safetensors")?;
28    builder.1.taesd(taesd_path);
29    Ok(builder)
30}
31
32/// Apply <https://huggingface.co/madebyollin/taesdxl> taesd autoencoder for faster decoding (SDXL)
33pub fn taesd_xl(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
34    let taesd_path =
35        download_file_hf_hub("madebyollin/taesdxl", "diffusion_pytorch_model.safetensors")?;
36    builder.1.taesd(taesd_path);
37    Ok(builder)
38}
39
40/// Apply <https://huggingface.co/cqyan/hybrid-sd-tinyvae> taesd autoencoder for faster decoding (SD v1/v2)
41pub fn hybrid_taesd(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
42    let taesd_path = download_file_hf_hub(
43        "cqyan/hybrid-sd-tinyvae",
44        "diffusion_pytorch_model.safetensors",
45    )?;
46    builder.1.taesd(taesd_path);
47    Ok(builder)
48}
49
50/// Apply <https://huggingface.co/cqyan/hybrid-sd-tinyvae-xl> taesd autoencoder for faster decoding (SDXL)
51pub fn hybrid_taesd_xl(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
52    let taesd_path = download_file_hf_hub(
53        "cqyan/hybrid-sd-tinyvae-xl",
54        "diffusion_pytorch_model.safetensors",
55    )?;
56    builder.1.taesd(taesd_path);
57    Ok(builder)
58}
59
60/// Apply <https://huggingface.co/latent-consistency/lcm-lora-sdv1-5> to reduce inference steps for SD v1 between 2-8
61/// cfg_scale 1. 4 steps.
62pub fn lcm_lora_sd_1_5(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
63    let lora_path = download_file_hf_hub(
64        "latent-consistency/lcm-lora-sdv1-5",
65        "pytorch_lora_weights.safetensors",
66    )?;
67    builder.1.lora_model(&lora_path);
68    builder.0.cfg_scale(1.).steps(4);
69    Ok(builder)
70}
71
72/// Apply <https://huggingface.co/latent-consistency/lcm-lora-sdxl> to reduce inference steps for SD v1 between 2-8 (default 8)
73/// Enabled [SampleMethod::LCM]. cfg_scale 2. 8 steps.
74pub fn lcm_lora_sdxl_base_1_0(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
75    let lora_path = download_file_hf_hub(
76        "latent-consistency/lcm-lora-sdxl",
77        "pytorch_lora_weights.safetensors",
78    )?;
79    builder.1.lora_model(&lora_path);
80    builder
81        .0
82        .cfg_scale(2.)
83        .steps(8)
84        .sampling_method(SampleMethod::LCM);
85    Ok(builder)
86}
87
88/// Apply <https://huggingface.co/comfyanonymous/flux_text_encoders/blob/main/t5xxl_fp8_e4m3fn.safetensors> Fp8 t5xxl text encoder to reduce memory usage
89pub fn t5xxl_fp8_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
90    let t5xxl_path = download_file_hf_hub(
91        "comfyanonymous/flux_text_encoders",
92        "t5xxl_fp8_e4m3fn.safetensors",
93    )?;
94
95    builder.1.t5xxl(t5xxl_path);
96    Ok(builder)
97}
98
99/// Apply <https://huggingface.co/comfyanonymous/flux_text_encoders/blob/main/t5xxl_fp16.safetensors>
100/// Default for flux_1_dev/schnell
101pub fn t5xxl_fp16_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
102    let t5xxl_path = download_file_hf_hub(
103        "comfyanonymous/flux_text_encoders",
104        "t5xxl_fp16.safetensors",
105    )?;
106
107    builder.1.t5xxl(t5xxl_path);
108    Ok(builder)
109}
110
111/// Apply <https://huggingface.co/Green-Sky/flux.1-schnell-GGUF/blob/main/t5xxl_q2_k.gguf>
112pub fn t5xxl_q2_k_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
113    let t5xxl_path = download_file_hf_hub("Green-Sky/flux.1-schnell-GGUF", "t5xxl_q2_k.gguf")?;
114
115    builder.1.t5xxl(t5xxl_path);
116    Ok(builder)
117}
118
119/// Apply <https://huggingface.co/Green-Sky/flux.1-schnell-GGUF/blob/main/t5xxl_q3_k.gguf>
120pub fn t5xxl_q3_k_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
121    let t5xxl_path = download_file_hf_hub("Green-Sky/flux.1-schnell-GGUF", "t5xxl_q3_k.gguf")?;
122
123    builder.1.t5xxl(t5xxl_path);
124    Ok(builder)
125}
126
127/// Apply <https://huggingface.co/Green-Sky/flux.1-schnell-GGUF/blob/main/t5xxl_q4_k.gguf>
128/// Default for flux_1_mini
129pub fn t5xxl_q4_k_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
130    let t5xxl_path = download_file_hf_hub("Green-Sky/flux.1-schnell-GGUF", "t5xxl_q4_k.gguf")?;
131
132    builder.1.t5xxl(t5xxl_path);
133    Ok(builder)
134}
135
136/// Apply <https://huggingface.co/Green-Sky/flux.1-schnell-GGUF/blob/main/t5xxl_q8_0.gguf>
137pub fn t5xxl_q8_0_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
138    let t5xxl_path = download_file_hf_hub("Green-Sky/flux.1-schnell-GGUF", "t5xxl_q8_0.gguf")?;
139
140    builder.1.t5xxl(t5xxl_path);
141    Ok(builder)
142}
143
144pub fn offload_params_to_cpu(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
145    builder.1.offload_params_to_cpu(true);
146    Ok(builder)
147}
148
#[cfg(test)]
mod tests {
    use super::{
        hybrid_taesd, hybrid_taesd_xl, lcm_lora_sd_1_5, lcm_lora_sdxl_base_1_0,
        offload_params_to_cpu, taesd, taesd_xl,
    };
    use crate::{
        api::{self, gen_img},
        preset::{Modifier, Preset, PresetBuilder},
        util::set_hf_token,
    };

    /// Prompt shared by every generation test in this module.
    const PROMPT: &str = "a lovely cat holding a sign says 'diffusion-rs'";

    /// Build the configs for `preset` with `modifier` applied, then generate an image.
    ///
    /// Panics (failing the calling test) if either the build or the generation fails.
    fn generate_with(preset: Preset, modifier: Modifier) {
        let mut configs = PresetBuilder::default()
            .preset(preset)
            .prompt(PROMPT)
            .with_modifier(modifier)
            .build()
            .unwrap();
        gen_img(&mut configs.0, &mut configs.1).unwrap();
    }

    // Every test below is `#[ignore]`d, so it only runs when explicitly requested
    // (e.g. `cargo test -- --ignored`).

    #[test]
    #[ignore]
    fn test_taesd() {
        generate_with(Preset::StableDiffusion1_5, taesd);
    }

    #[test]
    #[ignore]
    fn test_taesd_xl() {
        generate_with(Preset::SDXLTurbo1_0Fp16, taesd_xl);
    }

    #[test]
    #[ignore]
    fn test_hybrid_taesd() {
        generate_with(Preset::StableDiffusion1_5, hybrid_taesd);
    }

    #[test]
    #[ignore]
    fn test_hybrid_taesd_xl() {
        generate_with(Preset::SDXLTurbo1_0Fp16, hybrid_taesd_xl);
    }

    #[test]
    #[ignore]
    fn test_lcm_lora_sd_1_5() {
        generate_with(Preset::StableDiffusion1_5, lcm_lora_sd_1_5);
    }

    #[test]
    #[ignore]
    fn test_lcm_lora_sdxl_base_1_0() {
        generate_with(Preset::SDXLBase1_0, lcm_lora_sdxl_base_1_0);
    }

    #[test]
    #[ignore]
    fn test_offload_params_to_cpu() {
        // `include_str!` embeds the token at compile time, so `../token.txt` (relative to this
        // source file) must exist whenever the test target is built.
        let token = include_str!("../token.txt");
        set_hf_token(token);

        let preset = Preset::Flux1Schnell(api::WeightType::SD_TYPE_Q2_K);
        generate_with(preset, offload_params_to_cpu);
    }
}