use hf_hub::api::sync::ApiError;
use crate::{
api::{ConfigBuilder, SampleMethod},
util::download_file_hf_hub,
};
/// Configures the RealESRGAN x4plus anime 6B upscaler.
///
/// Fetches `RealESRGAN_x4plus_anime_6B.pth` from the
/// <https://huggingface.co/ximso/RealESRGAN_x4plus_anime_6B> repository and sets the
/// resulting local path as the builder's upscale model.
///
/// # Errors
///
/// Propagates the [`ApiError`] returned when the Hugging Face Hub download fails.
pub fn real_esrgan_x4plus_anime_6_b(mut builder: ConfigBuilder) -> Result<ConfigBuilder, ApiError> {
    const REPO: &str = "ximso/RealESRGAN_x4plus_anime_6B";
    const WEIGHTS: &str = "RealESRGAN_x4plus_anime_6B.pth";

    let model = download_file_hf_hub(REPO, WEIGHTS)?;
    builder.upscale_model(model);
    Ok(builder)
}
/// Swaps in the SDXL VAE published at <https://huggingface.co/madebyollin/sdxl-vae-fp16-fix>.
///
/// Downloads `sdxl.vae.safetensors` from that repository and registers the local file
/// as the builder's VAE.
///
/// # Errors
///
/// Propagates the [`ApiError`] returned when the Hugging Face Hub download fails.
pub fn sdxl_vae_fp16_fix(mut builder: ConfigBuilder) -> Result<ConfigBuilder, ApiError> {
    const REPO: &str = "madebyollin/sdxl-vae-fp16-fix";
    const WEIGHTS: &str = "sdxl.vae.safetensors";

    let vae = download_file_hf_hub(REPO, WEIGHTS)?;
    builder.vae(vae);
    Ok(builder)
}
/// Weight file name shared by every TAESD-style Hugging Face repository used below.
const TAESD_WEIGHTS_FILE: &str = "diffusion_pytorch_model.safetensors";

/// Downloads [`TAESD_WEIGHTS_FILE`] from the Hugging Face repository `repo` and
/// registers the local path with `ConfigBuilder::taesd`.
///
/// All TAESD modifiers in this module differ only in their source repository, so they
/// delegate here instead of repeating the download/register logic.
fn taesd_from_repo(mut builder: ConfigBuilder, repo: &str) -> Result<ConfigBuilder, ApiError> {
    let taesd_path = download_file_hf_hub(repo, TAESD_WEIGHTS_FILE)?;
    builder.taesd(taesd_path);
    Ok(builder)
}

/// Uses the TAESD weights from <https://huggingface.co/madebyollin/taesd>.
///
/// Paired with an SD 1.5 preset in this module's tests. The `_xl` variant targets SDXL.
///
/// # Errors
///
/// Propagates the [`ApiError`] returned when the Hugging Face Hub download fails.
pub fn taesd(builder: ConfigBuilder) -> Result<ConfigBuilder, ApiError> {
    taesd_from_repo(builder, "madebyollin/taesd")
}

/// Uses the TAESD XL weights from <https://huggingface.co/madebyollin/taesdxl>.
///
/// Paired with an SDXL preset in this module's tests.
///
/// # Errors
///
/// Propagates the [`ApiError`] returned when the Hugging Face Hub download fails.
pub fn taesd_xl(builder: ConfigBuilder) -> Result<ConfigBuilder, ApiError> {
    taesd_from_repo(builder, "madebyollin/taesdxl")
}

/// Uses the hybrid tiny VAE from <https://huggingface.co/cqyan/hybrid-sd-tinyvae>
/// as the builder's TAESD model.
///
/// Paired with an SD 1.5 preset in this module's tests.
///
/// # Errors
///
/// Propagates the [`ApiError`] returned when the Hugging Face Hub download fails.
pub fn hybrid_taesd(builder: ConfigBuilder) -> Result<ConfigBuilder, ApiError> {
    taesd_from_repo(builder, "cqyan/hybrid-sd-tinyvae")
}

/// Uses the hybrid tiny VAE XL from <https://huggingface.co/cqyan/hybrid-sd-tinyvae-xl>
/// as the builder's TAESD model.
///
/// Paired with an SDXL preset in this module's tests.
///
/// # Errors
///
/// Propagates the [`ApiError`] returned when the Hugging Face Hub download fails.
pub fn hybrid_taesd_xl(builder: ConfigBuilder) -> Result<ConfigBuilder, ApiError> {
    taesd_from_repo(builder, "cqyan/hybrid-sd-tinyvae-xl")
}
/// Applies the LCM-LoRA for Stable Diffusion 1.5
/// (<https://huggingface.co/latent-consistency/lcm-lora-sdv1-5>).
///
/// Downloads `pytorch_lora_weights.safetensors`, registers it through
/// `ConfigBuilder::lora_model`, and sets `cfg_scale` to `1.0` and `steps` to `4`.
/// The sampling method is left unchanged. `lcm_lora_sdxl_base_1_0`, by contrast,
/// also selects [`SampleMethod::LCM`].
///
/// # Errors
///
/// Propagates the [`ApiError`] returned when the Hugging Face Hub download fails.
pub fn lcm_lora_sd_1_5(mut builder: ConfigBuilder) -> Result<ConfigBuilder, ApiError> {
    let weights = download_file_hf_hub(
        "latent-consistency/lcm-lora-sdv1-5",
        "pytorch_lora_weights.safetensors",
    )?;

    builder.lora_model(&weights);
    builder.cfg_scale(1.0);
    builder.steps(4);

    Ok(builder)
}
/// Applies the LCM-LoRA for SDXL base 1.0
/// (<https://huggingface.co/latent-consistency/lcm-lora-sdxl>).
///
/// Downloads `pytorch_lora_weights.safetensors` and registers it through
/// `ConfigBuilder::lora_model`. It then sets `cfg_scale` to `2.0` and `steps` to `8`,
/// and switches the sampler to [`SampleMethod::LCM`].
///
/// # Errors
///
/// Propagates the [`ApiError`] returned when the Hugging Face Hub download fails.
pub fn lcm_lora_sdxl_base_1_0(mut builder: ConfigBuilder) -> Result<ConfigBuilder, ApiError> {
    let weights = download_file_hf_hub(
        "latent-consistency/lcm-lora-sdxl",
        "pytorch_lora_weights.safetensors",
    )?;

    builder.lora_model(&weights);
    builder.cfg_scale(2.0);
    builder.steps(8);
    builder.sampling_method(SampleMethod::LCM);

    Ok(builder)
}
/// Smoke tests that apply each modifier to a preset and run a full `txt2img` generation.
///
/// Every test is `#[ignore]`d by default. Each one downloads weights through
/// `download_file_hf_hub` and runs image generation, which presumably needs network
/// access and significant time. Run them explicitly with `cargo test -- --ignored`.
#[cfg(test)]
mod tests {
    use crate::{
        api::txt2img,
        preset::{Modifier, Preset, PresetBuilder},
    };

    use super::{
        hybrid_taesd, hybrid_taesd_xl, lcm_lora_sd_1_5, lcm_lora_sdxl_base_1_0,
        real_esrgan_x4plus_anime_6_b, sdxl_vae_fp16_fix, taesd, taesd_xl,
    };

    static PROMPT: &str = "a lovely duck drinking water from a bottle";

    /// Builds a config from `preset` with modifier `m` applied, then runs `txt2img` on it.
    ///
    /// # Panics
    ///
    /// Panics, failing the calling test, if the config cannot be built or generation fails.
    fn run(preset: Preset, m: Modifier) {
        let config = PresetBuilder::default()
            .preset(preset)
            .prompt(PROMPT)
            .with_modifier(m)
            .build()
            .expect("failed to build preset config with modifier");
        txt2img(config).expect("txt2img generation failed");
    }

    #[ignore]
    #[test]
    fn test_real_esrgan_x4plus_anime_6_b() {
        run(Preset::StableDiffusion1_5, real_esrgan_x4plus_anime_6_b);
    }

    #[ignore]
    #[test]
    fn test_sdxl_vae_fp16_fix() {
        run(Preset::SDXLTurbo1_0Fp16, sdxl_vae_fp16_fix);
    }

    #[ignore]
    #[test]
    fn test_taesd() {
        run(Preset::StableDiffusion1_5, taesd);
    }

    #[ignore]
    #[test]
    fn test_taesd_xl() {
        run(Preset::SDXLTurbo1_0Fp16, taesd_xl);
    }

    #[ignore]
    #[test]
    fn test_hybrid_taesd() {
        run(Preset::StableDiffusion1_5, hybrid_taesd);
    }

    #[ignore]
    #[test]
    fn test_hybrid_taesd_xl() {
        run(Preset::SDXLTurbo1_0Fp16, hybrid_taesd_xl);
    }

    #[ignore]
    #[test]
    fn test_lcm_lora_sd_1_5() {
        run(Preset::StableDiffusion1_5, lcm_lora_sd_1_5);
    }

    #[ignore]
    #[test]
    fn test_lcm_lora_sdxl_base_1_0() {
        run(Preset::SDXLBase1_0, lcm_lora_sdxl_base_1_0);
    }
}