diffusion_rs/
modifier.rs

1use hf_hub::api::sync::ApiError;
2
3use crate::{
4    api::{LoraSpec, PreviewType, SampleMethod},
5    preset::ConfigsBuilder,
6    util::download_file_hf_hub,
7};
8
9/// Add the <https://huggingface.co/ximso/RealESRGAN_x4plus_anime_6B> upscaler
10pub fn real_esrgan_x4plus_anime_6_b(
11    mut builder: ConfigsBuilder,
12) -> Result<ConfigsBuilder, ApiError> {
13    let upscaler_path = download_file_hf_hub(
14        "ximso/RealESRGAN_x4plus_anime_6B",
15        "RealESRGAN_x4plus_anime_6B.pth",
16    )?;
17    builder.1.upscale_model(upscaler_path);
18    Ok(builder)
19}
20
21/// Apply <https://huggingface.co/madebyollin/sdxl-vae-fp16-fix> to avoid black images with xl models
22pub fn sdxl_vae_fp16_fix(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
23    let vae_path = download_file_hf_hub("madebyollin/sdxl-vae-fp16-fix", "sdxl.vae.safetensors")?;
24    builder.1.vae(vae_path);
25    Ok(builder)
26}
27
28/// Apply <https://huggingface.co/madebyollin/taesd> taesd autoencoder for faster decoding (SD v1/v2)
29pub fn taesd(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
30    let taesd_path =
31        download_file_hf_hub("madebyollin/taesd", "diffusion_pytorch_model.safetensors")?;
32    builder.1.taesd(taesd_path);
33    Ok(builder)
34}
35
36/// Apply <https://huggingface.co/madebyollin/taesdxl> taesd autoencoder for faster decoding (SDXL)
37pub fn taesd_xl(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
38    let taesd_path =
39        download_file_hf_hub("madebyollin/taesdxl", "diffusion_pytorch_model.safetensors")?;
40    builder.1.taesd(taesd_path);
41    Ok(builder)
42}
43
44/// Apply <https://huggingface.co/cqyan/hybrid-sd-tinyvae> taesd autoencoder for faster decoding (SD v1/v2)
45pub fn hybrid_taesd(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
46    let taesd_path = download_file_hf_hub(
47        "cqyan/hybrid-sd-tinyvae",
48        "diffusion_pytorch_model.safetensors",
49    )?;
50    builder.1.taesd(taesd_path);
51    Ok(builder)
52}
53
54/// Apply <https://huggingface.co/cqyan/hybrid-sd-tinyvae-xl> taesd autoencoder for faster decoding (SDXL)
55pub fn hybrid_taesd_xl(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
56    let taesd_path = download_file_hf_hub(
57        "cqyan/hybrid-sd-tinyvae-xl",
58        "diffusion_pytorch_model.safetensors",
59    )?;
60    builder.1.taesd(taesd_path);
61    Ok(builder)
62}
63
64/// Apply <https://huggingface.co/latent-consistency/lcm-lora-sdv1-5> to reduce inference steps for SD v1 between 2-8 (default 8)
65/// cfg_scale 1. 8 steps.
66pub fn lcm_lora_sd_1_5(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
67    let lora_path = download_file_hf_hub(
68        "latent-consistency/lcm-lora-sdv1-5",
69        "pytorch_lora_weights.safetensors",
70    )?;
71    builder.1.lora_models(
72        lora_path.parent().unwrap(),
73        vec![LoraSpec {
74            file_name: "pytorch_lora_weights".to_string(),
75            is_high_noise: false,
76            multiplier: 1.0,
77        }],
78    );
79    builder.0.cfg_scale(1.).steps(8);
80    Ok(builder)
81}
82
83/// Apply <https://huggingface.co/latent-consistency/lcm-lora-sdxl> to reduce inference steps for SD v1 between 2-8 (default 8)
84/// Enabled [SampleMethod::LCM_SAMPLE_METHOD]. cfg_scale 2. 8 steps.
85pub fn lcm_lora_sdxl_base_1_0(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
86    let lora_path = download_file_hf_hub(
87        "latent-consistency/lcm-lora-sdxl",
88        "pytorch_lora_weights.safetensors",
89    )?;
90
91    builder.1.lora_models(
92        lora_path.parent().unwrap(),
93        vec![LoraSpec {
94            file_name: "pytorch_lora_weights".to_string(),
95            is_high_noise: false,
96            multiplier: 1.0,
97        }],
98    );
99    builder
100        .0
101        .cfg_scale(2.)
102        .steps(8)
103        .sampling_method(SampleMethod::LCM_SAMPLE_METHOD);
104    Ok(builder)
105}
106
107/// Apply <https://huggingface.co/comfyanonymous/flux_text_encoders/blob/main/t5xxl_fp8_e4m3fn.safetensors> fp8_e4m3fn t5xxl text encoder to reduce memory usage
108pub fn t5xxl_fp8_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
109    let t5xxl_path = download_file_hf_hub(
110        "comfyanonymous/flux_text_encoders",
111        "t5xxl_fp8_e4m3fn.safetensors",
112    )?;
113
114    builder.1.t5xxl(t5xxl_path);
115    Ok(builder)
116}
117
118/// Apply <https://huggingface.co/comfyanonymous/flux_text_encoders/blob/main/t5xxl_fp16.safetensors>
119/// Default for flux_1_dev/schnell
120pub fn t5xxl_fp16_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
121    let t5xxl_path = download_file_hf_hub(
122        "comfyanonymous/flux_text_encoders",
123        "t5xxl_fp16.safetensors",
124    )?;
125
126    builder.1.t5xxl(t5xxl_path);
127    Ok(builder)
128}
129
130/// Apply <https://huggingface.co/Green-Sky/flux.1-schnell-GGUF/blob/main/t5xxl_q2_k.gguf>
131pub fn t5xxl_q2_k_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
132    let t5xxl_path = download_file_hf_hub("Green-Sky/flux.1-schnell-GGUF", "t5xxl_q2_k.gguf")?;
133
134    builder.1.t5xxl(t5xxl_path);
135    Ok(builder)
136}
137
138/// Apply <https://huggingface.co/Green-Sky/flux.1-schnell-GGUF/blob/main/t5xxl_q3_k.gguf>
139pub fn t5xxl_q3_k_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
140    let t5xxl_path = download_file_hf_hub("Green-Sky/flux.1-schnell-GGUF", "t5xxl_q3_k.gguf")?;
141
142    builder.1.t5xxl(t5xxl_path);
143    Ok(builder)
144}
145
146/// Apply <https://huggingface.co/Green-Sky/flux.1-schnell-GGUF/blob/main/t5xxl_q4_k.gguf>
147/// Default for flux_1_mini
148pub fn t5xxl_q4_k_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
149    let t5xxl_path = download_file_hf_hub("Green-Sky/flux.1-schnell-GGUF", "t5xxl_q4_k.gguf")?;
150
151    builder.1.t5xxl(t5xxl_path);
152    Ok(builder)
153}
154
155/// Apply <https://huggingface.co/Green-Sky/flux.1-schnell-GGUF/blob/main/t5xxl_q8_0.gguf>
156pub fn t5xxl_q8_0_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
157    let t5xxl_path = download_file_hf_hub("Green-Sky/flux.1-schnell-GGUF", "t5xxl_q8_0.gguf")?;
158
159    builder.1.t5xxl(t5xxl_path);
160    Ok(builder)
161}
162
/// Enable offloading of model parameters to the CPU.
/// NOTE(review): presumably trades inference speed for lower device memory
/// usage — confirm against the backend's documentation.
pub fn offload_params_to_cpu(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
    builder.1.offload_params_to_cpu(true);
    Ok(builder)
}
167
168/// Apply <https://huggingface.co/kylielee505/mylcmlorassd> to reduce inference steps for SD v1 between 2-8 (default 8)
169/// cfg_scale 1. 8 steps.
170pub fn lcm_lora_ssd_1b(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
171    let lora_path = download_file_hf_hub(
172        "kylielee505/mylcmlorassd",
173        "pytorch_lora_weights.safetensors",
174    )?;
175    builder.1.lora_models(
176        lora_path.parent().unwrap(),
177        vec![LoraSpec {
178            file_name: "pytorch_lora_weights".to_string(),
179            is_high_noise: false,
180            multiplier: 1.0,
181        }],
182    );
183    builder.0.cfg_scale(1.).steps(8);
184    Ok(builder)
185}
186
/// Enable vae tiling.
/// NOTE(review): tiling typically decodes the image in patches to lower peak
/// memory — confirm exact semantics in the backend's documentation.
pub fn vae_tiling(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
    builder.1.vae_tiling(true);
    Ok(builder)
}
192
/// Enable preview with [crate::api::PreviewType::PREVIEW_PROJ]
/// (set on the generation-config half of the builder tuple).
pub fn preview_proj(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
    builder.0.preview_mode(PreviewType::PREVIEW_PROJ);
    Ok(builder)
}
198
/// Enable preview with [crate::api::PreviewType::PREVIEW_TAE]
/// (set on the generation-config half of the builder tuple).
pub fn preview_tae(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
    builder.0.preview_mode(PreviewType::PREVIEW_TAE);
    Ok(builder)
}
204
/// Enable preview with [crate::api::PreviewType::PREVIEW_VAE]
/// (set on the generation-config half of the builder tuple).
pub fn preview_vae(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
    builder.0.preview_mode(PreviewType::PREVIEW_VAE);
    Ok(builder)
}
210
/// Enable easycache support with default values
/// (set on the model-config half of the builder tuple).
pub fn enable_easycache(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
    builder.1.easy_cache(true);
    Ok(builder)
}
216
/// Enable flash attention
/// (set on the model-config half of the builder tuple).
pub fn enable_flash_attention(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
    builder.1.flash_attention(true);
    Ok(builder)
}
222
#[cfg(test)]
mod tests {
    //! End-to-end smoke tests for each modifier. Every test is `#[ignore]`d:
    //! they download real model weights from the Hugging Face hub and run a
    //! full image generation, so they are only meant to be run on demand
    //! (`cargo test -- --ignored`).
    use crate::{
        api::gen_img,
        modifier::{
            enable_easycache, enable_flash_attention, lcm_lora_ssd_1b, offload_params_to_cpu,
            preview_proj, preview_tae, preview_vae, vae_tiling,
        },
        preset::{Flux1Weight, Modifier, Preset, PresetBuilder},
        util::set_hf_token,
    };

    use super::{
        hybrid_taesd, hybrid_taesd_xl, lcm_lora_sd_1_5, lcm_lora_sdxl_base_1_0, taesd, taesd_xl,
    };

    // Shared prompt for all smoke tests.
    static PROMPT: &str = "a lovely dynosaur made by crochet";

    /// Build the given preset with the modifier applied, then generate once.
    /// Panics (failing the test) if building or generation errors.
    fn run(preset: Preset, m: Modifier) {
        let (mut config, mut model_config) = PresetBuilder::default()
            .preset(preset)
            .prompt(PROMPT)
            .with_modifier(m)
            .build()
            .unwrap();
        gen_img(&mut config, &mut model_config).unwrap();
    }

    #[ignore]
    #[test]
    fn test_taesd() {
        run(Preset::StableDiffusion1_5, taesd);
    }

    #[ignore]
    #[test]
    fn test_taesd_xl() {
        run(Preset::SDXLTurbo1_0Fp16, taesd_xl);
    }

    #[ignore]
    #[test]
    fn test_hybrid_taesd() {
        run(Preset::StableDiffusion1_5, hybrid_taesd);
    }

    #[ignore]
    #[test]
    fn test_hybrid_taesd_xl() {
        run(Preset::SDXLTurbo1_0Fp16, hybrid_taesd_xl);
    }

    #[ignore]
    #[test]
    fn test_lcm_lora_sd_1_5() {
        run(Preset::StableDiffusion1_5, lcm_lora_sd_1_5);
    }

    #[ignore]
    #[test]
    fn test_lcm_lora_sdxl_base_1_0() {
        run(Preset::SDXLBase1_0, lcm_lora_sdxl_base_1_0);
    }

    #[ignore]
    #[test]
    fn test_offload_params_to_cpu() {
        // Flux weights are gated; the token is read from an untracked local file.
        set_hf_token(include_str!("../token.txt"));
        run(
            Preset::Flux1Schnell(Flux1Weight::Q2_K),
            offload_params_to_cpu,
        );
    }

    #[ignore]
    #[test]
    fn test_lcm_lora_ssd_1b() {
        run(
            Preset::SSD1B(crate::preset::SSD1BWeight::F8_E4M3),
            lcm_lora_ssd_1b,
        );
    }

    #[ignore]
    #[test]
    fn test_vae_tiling() {
        run(
            Preset::SSD1B(crate::preset::SSD1BWeight::F8_E4M3),
            vae_tiling,
        );
    }

    #[ignore]
    #[test]
    fn test_preview_proj() {
        run(Preset::SDXLTurbo1_0Fp16, preview_proj);
    }

    #[ignore]
    #[test]
    fn test_preview_tae() {
        run(Preset::SDXLTurbo1_0Fp16, preview_tae);
    }

    #[ignore]
    #[test]
    fn test_preview_vae() {
        run(Preset::SDXLTurbo1_0Fp16, preview_vae);
    }

    #[ignore]
    #[test]
    fn test_easy_cache() {
        // Flux weights are gated; the token is read from an untracked local file.
        set_hf_token(include_str!("../token.txt"));
        run(
            Preset::Flux1Mini(crate::preset::Flux1MiniWeight::Q2_K),
            enable_easycache,
        );
    }

    #[ignore]
    #[test]
    fn test_flash_attention() {
        // Flux weights are gated; the token is read from an untracked local file.
        set_hf_token(include_str!("../token.txt"));
        run(
            Preset::Flux1Mini(crate::preset::Flux1MiniWeight::Q2_K),
            enable_flash_attention,
        );
    }
}