1use hf_hub::api::sync::ApiError;
2
3use crate::{api::SampleMethod, preset::ConfigsBuilder, util::download_file_hf_hub};
4
5pub fn real_esrgan_x4plus_anime_6_b(
7 mut builder: ConfigsBuilder,
8) -> Result<ConfigsBuilder, ApiError> {
9 let upscaler_path = download_file_hf_hub(
10 "ximso/RealESRGAN_x4plus_anime_6B",
11 "RealESRGAN_x4plus_anime_6B.pth",
12 )?;
13 builder.1.upscale_model(upscaler_path);
14 Ok(builder)
15}
16
17pub fn sdxl_vae_fp16_fix(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
19 let vae_path = download_file_hf_hub("madebyollin/sdxl-vae-fp16-fix", "sdxl.vae.safetensors")?;
20 builder.1.vae(vae_path);
21 Ok(builder)
22}
23
24pub fn taesd(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
26 let taesd_path =
27 download_file_hf_hub("madebyollin/taesd", "diffusion_pytorch_model.safetensors")?;
28 builder.1.taesd(taesd_path);
29 Ok(builder)
30}
31
32pub fn taesd_xl(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
34 let taesd_path =
35 download_file_hf_hub("madebyollin/taesdxl", "diffusion_pytorch_model.safetensors")?;
36 builder.1.taesd(taesd_path);
37 Ok(builder)
38}
39
40pub fn hybrid_taesd(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
42 let taesd_path = download_file_hf_hub(
43 "cqyan/hybrid-sd-tinyvae",
44 "diffusion_pytorch_model.safetensors",
45 )?;
46 builder.1.taesd(taesd_path);
47 Ok(builder)
48}
49
50pub fn hybrid_taesd_xl(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
52 let taesd_path = download_file_hf_hub(
53 "cqyan/hybrid-sd-tinyvae-xl",
54 "diffusion_pytorch_model.safetensors",
55 )?;
56 builder.1.taesd(taesd_path);
57 Ok(builder)
58}
59
60pub fn lcm_lora_sd_1_5(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
63 let lora_path = download_file_hf_hub(
64 "latent-consistency/lcm-lora-sdv1-5",
65 "pytorch_lora_weights.safetensors",
66 )?;
67 builder.1.lora_model(&lora_path);
68 builder.0.cfg_scale(1.).steps(4);
69 Ok(builder)
70}
71
72pub fn lcm_lora_sdxl_base_1_0(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
75 let lora_path = download_file_hf_hub(
76 "latent-consistency/lcm-lora-sdxl",
77 "pytorch_lora_weights.safetensors",
78 )?;
79 builder.1.lora_model(&lora_path);
80 builder
81 .0
82 .cfg_scale(2.)
83 .steps(8)
84 .sampling_method(SampleMethod::LCM);
85 Ok(builder)
86}
87
88pub fn t5xxl_fp8_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
90 let t5xxl_path = download_file_hf_hub(
91 "comfyanonymous/flux_text_encoders",
92 "t5xxl_fp8_e4m3fn.safetensors",
93 )?;
94
95 builder.1.t5xxl(t5xxl_path);
96 Ok(builder)
97}
98
99pub fn t5xxl_fp16_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
102 let t5xxl_path = download_file_hf_hub(
103 "comfyanonymous/flux_text_encoders",
104 "t5xxl_fp16.safetensors",
105 )?;
106
107 builder.1.t5xxl(t5xxl_path);
108 Ok(builder)
109}
110
111pub fn t5xxl_q2_k_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
113 let t5xxl_path = download_file_hf_hub("Green-Sky/flux.1-schnell-GGUF", "t5xxl_q2_k.gguf")?;
114
115 builder.1.t5xxl(t5xxl_path);
116 Ok(builder)
117}
118
119pub fn t5xxl_q3_k_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
121 let t5xxl_path = download_file_hf_hub("Green-Sky/flux.1-schnell-GGUF", "t5xxl_q3_k.gguf")?;
122
123 builder.1.t5xxl(t5xxl_path);
124 Ok(builder)
125}
126
127pub fn t5xxl_q4_k_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
130 let t5xxl_path = download_file_hf_hub("Green-Sky/flux.1-schnell-GGUF", "t5xxl_q4_k.gguf")?;
131
132 builder.1.t5xxl(t5xxl_path);
133 Ok(builder)
134}
135
136pub fn t5xxl_q8_0_flux_1(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
138 let t5xxl_path = download_file_hf_hub("Green-Sky/flux.1-schnell-GGUF", "t5xxl_q8_0.gguf")?;
139
140 builder.1.t5xxl(t5xxl_path);
141 Ok(builder)
142}
143
144pub fn offload_params_to_cpu(mut builder: ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> {
145 builder.1.offload_params_to_cpu(true);
146 Ok(builder)
147}
148
#[cfg(test)]
mod tests {
    //! Manual end-to-end checks for the modifiers in this file.
    //!
    //! Every test is `#[ignore]`d because each one fetches model weights and runs a full
    //! image generation. Run them explicitly with `cargo test -- --ignored`.

    use super::{
        hybrid_taesd, hybrid_taesd_xl, lcm_lora_sd_1_5, lcm_lora_sdxl_base_1_0,
        offload_params_to_cpu, taesd, taesd_xl,
    };
    use crate::{
        api::{gen_img, WeightType},
        preset::{Modifier, Preset, PresetBuilder},
        util::set_hf_token,
    };

    /// Prompt shared by every generation test.
    const PROMPT: &str = "a lovely cat holding a sign says 'diffusion-rs'";

    /// Builds `preset` with [`PROMPT`] and `modifier` applied, then generates an image.
    ///
    /// Panics if either building the configuration or generating the image fails.
    fn generate_with(preset: Preset, modifier: Modifier) {
        let (mut config, mut model_config) = PresetBuilder::default()
            .preset(preset)
            .prompt(PROMPT)
            .with_modifier(modifier)
            .build()
            .unwrap();
        gen_img(&mut config, &mut model_config).unwrap();
    }

    #[test]
    #[ignore]
    fn test_taesd() {
        generate_with(Preset::StableDiffusion1_5, taesd);
    }

    #[test]
    #[ignore]
    fn test_taesd_xl() {
        generate_with(Preset::SDXLTurbo1_0Fp16, taesd_xl);
    }

    #[test]
    #[ignore]
    fn test_hybrid_taesd() {
        generate_with(Preset::StableDiffusion1_5, hybrid_taesd);
    }

    #[test]
    #[ignore]
    fn test_hybrid_taesd_xl() {
        generate_with(Preset::SDXLTurbo1_0Fp16, hybrid_taesd_xl);
    }

    #[test]
    #[ignore]
    fn test_lcm_lora_sd_1_5() {
        generate_with(Preset::StableDiffusion1_5, lcm_lora_sd_1_5);
    }

    #[test]
    #[ignore]
    fn test_lcm_lora_sdxl_base_1_0() {
        generate_with(Preset::SDXLBase1_0, lcm_lora_sdxl_base_1_0);
    }

    #[test]
    #[ignore]
    fn test_offload_params_to_cpu() {
        // NOTE(review): only this test sets an HF token, presumably because the FLUX.1
        // weights require authenticated access.
        set_hf_token(include_str!("../token.txt"));
        generate_with(
            Preset::Flux1Schnell(WeightType::SD_TYPE_Q2_K),
            offload_params_to_cpu,
        );
    }
}