1use hf_hub::api::sync::ApiError;
2
3use crate::{
4 api::{LoraSpec, PreviewType, SampleMethod},
5 preset::ConfigsBuilder,
6 util::download_file_hf_hub,
7};
8
9pub fn real_esrgan_x4plus_anime_6_b(
11 mut builder: ConfigsBuilder,
12) -> Result<ConfigsBuilder, ApiError> {
13 let upscaler_path = download_file_hf_hub(
14 "ximso/RealESRGAN_x4plus_anime_6B",
15 "RealESRGAN_x4plus_anime_6B.pth",
16 )?;
17 builder.1.upscale_model(upscaler_path);
18 Ok(builder)
19}
20
21pub fn sdxl_vae_fp16_fix(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
23 let vae_path = download_file_hf_hub("madebyollin/sdxl-vae-fp16-fix", "sdxl.vae.safetensors")?;
24 builder.1.vae(vae_path);
25 Ok(builder)
26}
27
28pub fn taesd(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
30 let taesd_path =
31 download_file_hf_hub("madebyollin/taesd", "diffusion_pytorch_model.safetensors")?;
32 builder.1.taesd(taesd_path);
33 Ok(builder)
34}
35
36pub fn taesd_xl(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
38 let taesd_path =
39 download_file_hf_hub("madebyollin/taesdxl", "diffusion_pytorch_model.safetensors")?;
40 builder.1.taesd(taesd_path);
41 Ok(builder)
42}
43
44pub fn hybrid_taesd(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
46 let taesd_path = download_file_hf_hub(
47 "cqyan/hybrid-sd-tinyvae",
48 "diffusion_pytorch_model.safetensors",
49 )?;
50 builder.1.taesd(taesd_path);
51 Ok(builder)
52}
53
54pub fn hybrid_taesd_xl(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
56 let taesd_path = download_file_hf_hub(
57 "cqyan/hybrid-sd-tinyvae-xl",
58 "diffusion_pytorch_model.safetensors",
59 )?;
60 builder.1.taesd(taesd_path);
61 Ok(builder)
62}
63
64pub fn lcm_lora_sd_1_5(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
67 let lora_path = download_file_hf_hub(
68 "latent-consistency/lcm-lora-sdv1-5",
69 "pytorch_lora_weights.safetensors",
70 )?;
71 builder.1.lora_models(
72 lora_path.parent().unwrap(),
73 vec![LoraSpec {
74 file_name: "pytorch_lora_weights".to_string(),
75 is_high_noise: false,
76 multiplier: 1.0,
77 }],
78 );
79 builder.0.cfg_scale(1.).steps(8);
80 Ok(builder)
81}
82
83pub fn lcm_lora_sdxl_base_1_0(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
86 let lora_path = download_file_hf_hub(
87 "latent-consistency/lcm-lora-sdxl",
88 "pytorch_lora_weights.safetensors",
89 )?;
90
91 builder.1.lora_models(
92 lora_path.parent().unwrap(),
93 vec![LoraSpec {
94 file_name: "pytorch_lora_weights".to_string(),
95 is_high_noise: false,
96 multiplier: 1.0,
97 }],
98 );
99 builder
100 .0
101 .cfg_scale(2.)
102 .steps(8)
103 .sampling_method(SampleMethod::LCM_SAMPLE_METHOD);
104 Ok(builder)
105}
106
107pub fn t5xxl_fp8_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
109 let t5xxl_path = download_file_hf_hub(
110 "comfyanonymous/flux_text_encoders",
111 "t5xxl_fp8_e4m3fn.safetensors",
112 )?;
113
114 builder.1.t5xxl(t5xxl_path);
115 Ok(builder)
116}
117
118pub fn t5xxl_fp16_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
121 let t5xxl_path = download_file_hf_hub(
122 "comfyanonymous/flux_text_encoders",
123 "t5xxl_fp16.safetensors",
124 )?;
125
126 builder.1.t5xxl(t5xxl_path);
127 Ok(builder)
128}
129
130pub fn t5xxl_q2_k_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
132 let t5xxl_path = download_file_hf_hub("Green-Sky/flux.1-schnell-GGUF", "t5xxl_q2_k.gguf")?;
133
134 builder.1.t5xxl(t5xxl_path);
135 Ok(builder)
136}
137
138pub fn t5xxl_q3_k_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
140 let t5xxl_path = download_file_hf_hub("Green-Sky/flux.1-schnell-GGUF", "t5xxl_q3_k.gguf")?;
141
142 builder.1.t5xxl(t5xxl_path);
143 Ok(builder)
144}
145
146pub fn t5xxl_q4_k_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
149 let t5xxl_path = download_file_hf_hub("Green-Sky/flux.1-schnell-GGUF", "t5xxl_q4_k.gguf")?;
150
151 builder.1.t5xxl(t5xxl_path);
152 Ok(builder)
153}
154
155pub fn t5xxl_q8_0_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
157 let t5xxl_path = download_file_hf_hub("Green-Sky/flux.1-schnell-GGUF", "t5xxl_q8_0.gguf")?;
158
159 builder.1.t5xxl(t5xxl_path);
160 Ok(builder)
161}
162
163pub fn offload_params_to_cpu(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
165 builder.1.offload_params_to_cpu(true);
166 Ok(builder)
167}
168
169pub fn lcm_lora_ssd_1b(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
172 let lora_path = download_file_hf_hub(
173 "kylielee505/mylcmlorassd",
174 "pytorch_lora_weights.safetensors",
175 )?;
176 builder.1.lora_models(
177 lora_path.parent().unwrap(),
178 vec![LoraSpec {
179 file_name: "pytorch_lora_weights".to_string(),
180 is_high_noise: false,
181 multiplier: 1.0,
182 }],
183 );
184 builder.0.cfg_scale(1.).steps(8);
185 Ok(builder)
186}
187
188pub fn vae_tiling(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
190 builder.1.vae_tiling(true);
191 Ok(builder)
192}
193
194pub fn preview_proj(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
196 builder.0.preview_mode(PreviewType::PREVIEW_PROJ);
197 Ok(builder)
198}
199
200pub fn preview_tae(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
202 builder.0.preview_mode(PreviewType::PREVIEW_TAE);
203 Ok(builder)
204}
205
206pub fn preview_vae(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
208 builder.0.preview_mode(PreviewType::PREVIEW_VAE);
209 Ok(builder)
210}
211
212pub fn enable_flash_attention(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
214 builder.1.flash_attention(true);
215 Ok(builder)
216}
217
218pub fn segmind_vega_rt_lcm_lora(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
220 let lora_path =
221 download_file_hf_hub("segmind/Segmind-VegaRT", "pytorch_lora_weights.safetensors")?;
222 builder.1.lora_models(
223 lora_path.parent().unwrap(),
224 vec![LoraSpec {
225 file_name: "pytorch_lora_weights".to_string(),
226 is_high_noise: false,
227 multiplier: 1.0,
228 }],
229 );
230 builder.0.guidance(0.).steps(4);
231 Ok(builder)
232}
233
#[cfg(test)]
mod tests {
    use hf_hub::api::sync::ApiError;

    use crate::{
        api::gen_img,
        preset::{
            ConfigsBuilder, Flux1MiniWeight, Flux1Weight, Preset, PresetBuilder, SSD1BWeight,
        },
        util::set_hf_token,
    };

    use super::{
        enable_flash_attention, hybrid_taesd, hybrid_taesd_xl, lcm_lora_sd_1_5,
        lcm_lora_sdxl_base_1_0, lcm_lora_ssd_1b, offload_params_to_cpu, preview_proj,
        preview_tae, preview_vae, segmind_vega_rt_lcm_lora, taesd, taesd_xl, vae_tiling,
    };

    static PROMPT: &str = "a lovely dinosaur made by crochet";

    /// Reads the Hugging Face access token from `token.txt` at the crate root.
    ///
    /// This used to be `include_str!("../token.txt")`, which is expanded at compile time. That
    /// made the whole test module fail to compile whenever the untracked token file was
    /// missing, even though every test here is `#[ignore]`d. Reading at runtime means only
    /// the tests that actually need a token fail.
    ///
    /// Surrounding whitespace (such as a trailing newline) is trimmed. The string is leaked
    /// to keep the `&'static str` type the previous `include_str!` produced.
    ///
    /// # Panics
    ///
    /// If `token.txt` cannot be read.
    fn hf_token() -> &'static str {
        // `../token.txt` relative to this file is the crate root, assuming the file lives at
        // `src/modifier.rs` — TODO confirm if the module is moved.
        let path = concat!(env!("CARGO_MANIFEST_DIR"), "/token.txt");
        let token = std::fs::read_to_string(path)
            .unwrap_or_else(|err| panic!("failed to read Hugging Face token from {path}: {err}"));
        Box::leak(token.trim().to_owned().into_boxed_str())
    }

    /// Builds `preset` with the shared prompt and modifier `m`, then generates an image.
    fn run<F>(preset: Preset, m: F)
    where
        F: FnOnce(ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> + 'static,
    {
        let (mut config, mut model_config) = PresetBuilder::default()
            .preset(preset)
            .prompt(PROMPT)
            .with_modifier(m)
            .build()
            .unwrap();
        gen_img(&mut config, &mut model_config).unwrap();
    }

    #[ignore]
    #[test]
    fn test_taesd() {
        run(Preset::StableDiffusion1_5, taesd);
    }

    #[ignore]
    #[test]
    fn test_taesd_xl() {
        run(Preset::SDXLTurbo1_0, taesd_xl);
    }

    #[ignore]
    #[test]
    fn test_hybrid_taesd() {
        run(Preset::StableDiffusion1_5, hybrid_taesd);
    }

    #[ignore]
    #[test]
    fn test_hybrid_taesd_xl() {
        run(Preset::SDXLTurbo1_0, hybrid_taesd_xl);
    }

    #[ignore]
    #[test]
    fn test_lcm_lora_sd_1_5() {
        run(Preset::StableDiffusion1_5, lcm_lora_sd_1_5);
    }

    #[ignore]
    #[test]
    fn test_lcm_lora_sdxl_base_1_0() {
        run(Preset::SDXLBase1_0, lcm_lora_sdxl_base_1_0);
    }

    #[ignore]
    #[test]
    fn test_offload_params_to_cpu() {
        set_hf_token(hf_token());
        run(
            Preset::Flux1Schnell(Flux1Weight::Q2_K),
            offload_params_to_cpu,
        );
    }

    #[ignore]
    #[test]
    fn test_lcm_lora_ssd_1b() {
        run(Preset::SSD1B(SSD1BWeight::F8_E4M3), lcm_lora_ssd_1b);
    }

    #[ignore]
    #[test]
    fn test_vae_tiling() {
        run(Preset::SSD1B(SSD1BWeight::F8_E4M3), vae_tiling);
    }

    #[ignore]
    #[test]
    fn test_preview_proj() {
        run(Preset::SDXLTurbo1_0, preview_proj);
    }

    #[ignore]
    #[test]
    fn test_preview_tae() {
        run(Preset::SDXLTurbo1_0, preview_tae);
    }

    #[ignore]
    #[test]
    fn test_preview_vae() {
        run(Preset::SDXLTurbo1_0, preview_vae);
    }

    #[ignore]
    #[test]
    fn test_flash_attention() {
        set_hf_token(hf_token());
        run(
            Preset::Flux1Mini(Flux1MiniWeight::Q2_K),
            enable_flash_attention,
        );
    }

    #[ignore]
    #[test]
    fn test_segmind_vega_rt_lcm_lora() {
        run(Preset::SegmindVega, segmind_vega_rt_lcm_lora);
    }
}