// diffusion_rs/modifier.rs

use hf_hub::api::sync::ApiError;

use crate::{
    api::{LoraSpec, PreviewType, SampleMethod},
    preset::ConfigsBuilder,
    util::download_file_hf_hub,
};

/// Add the <https://huggingface.co/ximso/RealESRGAN_x4plus_anime_6B> upscaler
pub fn real_esrgan_x4plus_anime_6_b(
    mut builder: ConfigsBuilder,
) -> Result<ConfigsBuilder, ApiError> {
    let upscaler_path = download_file_hf_hub(
        "ximso/RealESRGAN_x4plus_anime_6B",
        "RealESRGAN_x4plus_anime_6B.pth",
    )?;
    builder.1.upscale_model(upscaler_path);
    Ok(builder)
}
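
// Usage sketch (hedged): modifiers like the one above plug into
// `PresetBuilder::with_modifier`, mirroring the `run` helper in the test
// module at the bottom of this file. The preset, prompt, and function name
// here are illustrative.
#[allow(dead_code)]
fn example_upscaled_render() {
    use crate::{
        api::gen_img,
        preset::{Preset, PresetBuilder},
    };

    let (mut config, mut model_config) = PresetBuilder::default()
        .preset(Preset::StableDiffusion1_5)
        .prompt("a lovely dinosaur made by crochet")
        .with_modifier(real_esrgan_x4plus_anime_6_b)
        .build()
        .unwrap();
    gen_img(&mut config, &mut model_config).unwrap();
}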

/// Apply <https://huggingface.co/madebyollin/sdxl-vae-fp16-fix> to avoid black images with SDXL models
pub fn sdxl_vae_fp16_fix(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
    let vae_path = download_file_hf_hub("madebyollin/sdxl-vae-fp16-fix", "sdxl.vae.safetensors")?;
    builder.1.vae(vae_path);
    Ok(builder)
}

/// Apply the <https://huggingface.co/madebyollin/taesd> TAESD autoencoder for faster decoding (SD v1/v2)
pub fn taesd(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
    let taesd_path =
        download_file_hf_hub("madebyollin/taesd", "diffusion_pytorch_model.safetensors")?;
    builder.1.taesd(taesd_path);
    Ok(builder)
}

/// Apply the <https://huggingface.co/madebyollin/taesdxl> TAESD autoencoder for faster decoding (SDXL)
pub fn taesd_xl(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
    let taesd_path =
        download_file_hf_hub("madebyollin/taesdxl", "diffusion_pytorch_model.safetensors")?;
    builder.1.taesd(taesd_path);
    Ok(builder)
}

/// Apply the <https://huggingface.co/cqyan/hybrid-sd-tinyvae> tiny autoencoder (TAESD-compatible) for faster decoding (SD v1/v2)
pub fn hybrid_taesd(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
    let taesd_path = download_file_hf_hub(
        "cqyan/hybrid-sd-tinyvae",
        "diffusion_pytorch_model.safetensors",
    )?;
    builder.1.taesd(taesd_path);
    Ok(builder)
}

/// Apply the <https://huggingface.co/cqyan/hybrid-sd-tinyvae-xl> tiny autoencoder (TAESD-compatible) for faster decoding (SDXL)
pub fn hybrid_taesd_xl(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
    let taesd_path = download_file_hf_hub(
        "cqyan/hybrid-sd-tinyvae-xl",
        "diffusion_pytorch_model.safetensors",
    )?;
    builder.1.taesd(taesd_path);
    Ok(builder)
}

/// Apply <https://huggingface.co/latent-consistency/lcm-lora-sdv1-5> to reduce SD v1 inference to 2-8 steps (default 8).
/// Sets cfg_scale to 1 and steps to 8.
pub fn lcm_lora_sd_1_5(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
    let lora_path = download_file_hf_hub(
        "latent-consistency/lcm-lora-sdv1-5",
        "pytorch_lora_weights.safetensors",
    )?;
    builder.1.lora_models(
        lora_path.parent().unwrap(),
        vec![LoraSpec {
            file_name: "pytorch_lora_weights".to_string(),
            is_high_noise: false,
            multiplier: 1.0,
        }],
    );
    builder.0.cfg_scale(1.).steps(8);
    Ok(builder)
}

/// Apply <https://huggingface.co/latent-consistency/lcm-lora-sdxl> to reduce SDXL inference to 2-8 steps (default 8).
/// Enables [SampleMethod::LCM_SAMPLE_METHOD]. Sets cfg_scale to 2 and steps to 8.
pub fn lcm_lora_sdxl_base_1_0(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
    let lora_path = download_file_hf_hub(
        "latent-consistency/lcm-lora-sdxl",
        "pytorch_lora_weights.safetensors",
    )?;

    builder.1.lora_models(
        lora_path.parent().unwrap(),
        vec![LoraSpec {
            file_name: "pytorch_lora_weights".to_string(),
            is_high_noise: false,
            multiplier: 1.0,
        }],
    );
    builder
        .0
        .cfg_scale(2.)
        .steps(8)
        .sampling_method(SampleMethod::LCM_SAMPLE_METHOD);
    Ok(builder)
}
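
// Composition sketch (hedged): since every modifier has the shape
// `ConfigsBuilder -> Result<ConfigsBuilder, ApiError>`, several can be chained
// inside one function or closure passed to `with_modifier`. Pairing the fp16
// VAE fix with the SDXL LCM LoRA is an illustrative combination, not a
// required one.
#[allow(dead_code)]
fn example_fast_sdxl(builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
    lcm_lora_sdxl_base_1_0(sdxl_vae_fp16_fix(builder)?)
}
// e.g. `.with_modifier(example_fast_sdxl)` on a `PresetBuilder`.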

/// Apply the <https://huggingface.co/comfyanonymous/flux_text_encoders/blob/main/t5xxl_fp8_e4m3fn.safetensors> fp8_e4m3fn t5xxl text encoder to reduce memory usage
pub fn t5xxl_fp8_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
    let t5xxl_path = download_file_hf_hub(
        "comfyanonymous/flux_text_encoders",
        "t5xxl_fp8_e4m3fn.safetensors",
    )?;

    builder.1.t5xxl(t5xxl_path);
    Ok(builder)
}

/// Apply the <https://huggingface.co/comfyanonymous/flux_text_encoders/blob/main/t5xxl_fp16.safetensors> fp16 t5xxl text encoder.
/// Default for flux_1_dev/schnell
pub fn t5xxl_fp16_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
    let t5xxl_path = download_file_hf_hub(
        "comfyanonymous/flux_text_encoders",
        "t5xxl_fp16.safetensors",
    )?;

    builder.1.t5xxl(t5xxl_path);
    Ok(builder)
}

/// Apply the <https://huggingface.co/Green-Sky/flux.1-schnell-GGUF/blob/main/t5xxl_q2_k.gguf> q2_k t5xxl text encoder
pub fn t5xxl_q2_k_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
    let t5xxl_path = download_file_hf_hub("Green-Sky/flux.1-schnell-GGUF", "t5xxl_q2_k.gguf")?;

    builder.1.t5xxl(t5xxl_path);
    Ok(builder)
}

/// Apply the <https://huggingface.co/Green-Sky/flux.1-schnell-GGUF/blob/main/t5xxl_q3_k.gguf> q3_k t5xxl text encoder
pub fn t5xxl_q3_k_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
    let t5xxl_path = download_file_hf_hub("Green-Sky/flux.1-schnell-GGUF", "t5xxl_q3_k.gguf")?;

    builder.1.t5xxl(t5xxl_path);
    Ok(builder)
}

/// Apply the <https://huggingface.co/Green-Sky/flux.1-schnell-GGUF/blob/main/t5xxl_q4_k.gguf> q4_k t5xxl text encoder.
/// Default for flux_1_mini
pub fn t5xxl_q4_k_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
    let t5xxl_path = download_file_hf_hub("Green-Sky/flux.1-schnell-GGUF", "t5xxl_q4_k.gguf")?;

    builder.1.t5xxl(t5xxl_path);
    Ok(builder)
}

/// Apply the <https://huggingface.co/Green-Sky/flux.1-schnell-GGUF/blob/main/t5xxl_q8_0.gguf> q8_0 t5xxl text encoder
pub fn t5xxl_q8_0_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
    let t5xxl_path = download_file_hf_hub("Green-Sky/flux.1-schnell-GGUF", "t5xxl_q8_0.gguf")?;

    builder.1.t5xxl(t5xxl_path);
    Ok(builder)
}
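
// Selection sketch (hedged): the GGUF t5xxl variants trade text-encoder
// quality for memory, q2_k being the smallest and q8_0 the closest to fp16.
// The GiB thresholds below are illustrative guesses, not measured figures.
#[allow(dead_code)]
fn example_pick_t5xxl(
    budget_gib: f32,
) -> fn(ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
    if budget_gib >= 10.0 {
        t5xxl_fp16_flux_1
    } else if budget_gib >= 5.0 {
        t5xxl_q8_0_flux_1
    } else if budget_gib >= 3.0 {
        t5xxl_q4_k_flux_1
    } else {
        t5xxl_q2_k_flux_1
    }
}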

/// Offload model parameters to CPU (for low-VRAM GPUs)
pub fn offload_params_to_cpu(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
    builder.1.offload_params_to_cpu(true);
    Ok(builder)
}

/// Apply <https://huggingface.co/kylielee505/mylcmlorassd> to reduce SSD-1B inference to 2-8 steps (default 8).
/// Sets cfg_scale to 1 and steps to 8.
pub fn lcm_lora_ssd_1b(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
    let lora_path = download_file_hf_hub(
        "kylielee505/mylcmlorassd",
        "pytorch_lora_weights.safetensors",
    )?;
    builder.1.lora_models(
        lora_path.parent().unwrap(),
        vec![LoraSpec {
            file_name: "pytorch_lora_weights".to_string(),
            is_high_noise: false,
            multiplier: 1.0,
        }],
    );
    builder.0.cfg_scale(1.).steps(8);
    Ok(builder)
}

/// Enable VAE tiling to reduce memory usage when decoding large images
pub fn vae_tiling(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
    builder.1.vae_tiling(true);
    Ok(builder)
}

/// Enable preview with [crate::api::PreviewType::PREVIEW_PROJ]
pub fn preview_proj(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
    builder.0.preview_mode(PreviewType::PREVIEW_PROJ);
    Ok(builder)
}

/// Enable preview with [crate::api::PreviewType::PREVIEW_TAE]
pub fn preview_tae(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
    builder.0.preview_mode(PreviewType::PREVIEW_TAE);
    Ok(builder)
}

/// Enable preview with [crate::api::PreviewType::PREVIEW_VAE]
pub fn preview_vae(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
    builder.0.preview_mode(PreviewType::PREVIEW_VAE);
    Ok(builder)
}

/// Enable flash attention
pub fn enable_flash_attention(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
    builder.1.flash_attention(true);
    Ok(builder)
}
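
// Recipe sketch (hedged): the flag-style modifiers compose freely, so a
// low-VRAM setup can stack CPU offloading, VAE tiling, and flash attention in
// a single modifier. Whether this combination helps depends on the backend
// and model; the pairing is illustrative.
#[allow(dead_code)]
fn example_low_vram(builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
    enable_flash_attention(vae_tiling(offload_params_to_cpu(builder)?)?)
}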

/// Apply <https://huggingface.co/segmind/Segmind-VegaRT> to [crate::preset::Preset::SegmindVega].
/// Sets guidance to 0 and steps to 4.
pub fn segmind_vega_rt_lcm_lora(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
    let lora_path =
        download_file_hf_hub("segmind/Segmind-VegaRT", "pytorch_lora_weights.safetensors")?;
    builder.1.lora_models(
        lora_path.parent().unwrap(),
        vec![LoraSpec {
            file_name: "pytorch_lora_weights".to_string(),
            is_high_noise: false,
            multiplier: 1.0,
        }],
    );
    builder.0.guidance(0.).steps(4);
    Ok(builder)
}
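
// Pattern sketch (hedged): a custom modifier only needs the
// `ConfigsBuilder -> Result<ConfigsBuilder, ApiError>` shape used throughout
// this module. The repository and file names below are placeholders, not real
// Hugging Face paths.
#[allow(dead_code)]
fn example_custom_lora(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
    // Download a LoRA from a (placeholder) HF repo, then register it by file
    // stem, exactly as the LCM modifiers above do.
    let lora_path = download_file_hf_hub("<owner>/<repo>", "<file>.safetensors")?;
    builder.1.lora_models(
        lora_path.parent().unwrap(),
        vec![LoraSpec {
            file_name: "<file>".to_string(),
            is_high_noise: false,
            multiplier: 1.0,
        }],
    );
    Ok(builder)
}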

#[cfg(test)]
mod tests {
    use hf_hub::api::sync::ApiError;

    use crate::{
        api::gen_img,
        modifier::{
            enable_flash_attention, lcm_lora_ssd_1b, offload_params_to_cpu, preview_proj,
            preview_tae, preview_vae, segmind_vega_rt_lcm_lora, vae_tiling,
        },
        preset::{ConfigsBuilder, Flux1Weight, Preset, PresetBuilder},
        util::set_hf_token,
    };

    use super::{
        hybrid_taesd, hybrid_taesd_xl, lcm_lora_sd_1_5, lcm_lora_sdxl_base_1_0, taesd, taesd_xl,
    };

    static PROMPT: &str = "a lovely dinosaur made by crochet";

    fn run<F>(preset: Preset, m: F)
    where
        F: FnOnce(ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> + 'static,
    {
        let (mut config, mut model_config) = PresetBuilder::default()
            .preset(preset)
            .prompt(PROMPT)
            .with_modifier(m)
            .build()
            .unwrap();
        gen_img(&mut config, &mut model_config).unwrap();
    }

    #[ignore]
    #[test]
    fn test_taesd() {
        run(Preset::StableDiffusion1_5, taesd);
    }

    #[ignore]
    #[test]
    fn test_taesd_xl() {
        run(Preset::SDXLTurbo1_0, taesd_xl);
    }

    #[ignore]
    #[test]
    fn test_hybrid_taesd() {
        run(Preset::StableDiffusion1_5, hybrid_taesd);
    }

    #[ignore]
    #[test]
    fn test_hybrid_taesd_xl() {
        run(Preset::SDXLTurbo1_0, hybrid_taesd_xl);
    }

    #[ignore]
    #[test]
    fn test_lcm_lora_sd_1_5() {
        run(Preset::StableDiffusion1_5, lcm_lora_sd_1_5);
    }

    #[ignore]
    #[test]
    fn test_lcm_lora_sdxl_base_1_0() {
        run(Preset::SDXLBase1_0, lcm_lora_sdxl_base_1_0);
    }

    #[ignore]
    #[test]
    fn test_offload_params_to_cpu() {
        set_hf_token(include_str!("../token.txt"));
        run(
            Preset::Flux1Schnell(Flux1Weight::Q2_K),
            offload_params_to_cpu,
        );
    }

    #[ignore]
    #[test]
    fn test_lcm_lora_ssd_1b() {
        run(
            Preset::SSD1B(crate::preset::SSD1BWeight::F8_E4M3),
            lcm_lora_ssd_1b,
        );
    }

    #[ignore]
    #[test]
    fn test_vae_tiling() {
        run(
            Preset::SSD1B(crate::preset::SSD1BWeight::F8_E4M3),
            vae_tiling,
        );
    }

    #[ignore]
    #[test]
    fn test_preview_proj() {
        run(Preset::SDXLTurbo1_0, preview_proj);
    }

    #[ignore]
    #[test]
    fn test_preview_tae() {
        run(Preset::SDXLTurbo1_0, preview_tae);
    }

    #[ignore]
    #[test]
    fn test_preview_vae() {
        run(Preset::SDXLTurbo1_0, preview_vae);
    }

    #[ignore]
    #[test]
    fn test_flash_attention() {
        set_hf_token(include_str!("../token.txt"));
        run(
            Preset::Flux1Mini(crate::preset::Flux1MiniWeight::Q2_K),
            enable_flash_attention,
        );
    }

    #[ignore]
    #[test]
    fn test_segmind_vega_rt_lcm_lora() {
        run(Preset::SegmindVega, segmind_vega_rt_lcm_lora);
    }
}