use derive_builder::Builder;
use hf_hub::api::sync::ApiError;
use strum::{EnumDiscriminants, EnumString, VariantNames};
use subenum::subenum;
use crate::{
api::{Config, ConfigBuilder, ConfigBuilderError, ModelConfig, ModelConfigBuilder},
preset_builder::{
anima, chroma, chroma_radiance, diff_instruct_star, dream_shaper_xl_2_1_turbo, flux_1_dev,
flux_1_mini, flux_1_schnell, flux_2_dev, flux_2_klein_4b, flux_2_klein_9b,
flux_2_klein_base_4b, flux_2_klein_base_9b, juggernaut_xl_11, nitro_sd_realism,
nitro_sd_vibrant, ovis_image, qwen_image, sd_turbo, sdxl_base_1_0, sdxl_turbo_1_0,
sdxs512_dream_shaper, segmind_vega, ssd_1b, stable_diffusion_1_4, stable_diffusion_1_5,
stable_diffusion_2_1, stable_diffusion_3_5_large, stable_diffusion_3_5_large_turbo,
stable_diffusion_3_5_medium, stable_diffusion_3_medium, twinflow_z_image_turbo,
z_image_turbo,
},
};
/// Weight/tensor storage types that a preset can be asked to load.
///
/// The `#[subenum(...)]` attribute below generates one narrower enum per model family
/// (`Flux1Weight`, `ChromaWeight`, `QwenImageWeight`, ...). A variant is a member of every
/// sub-enum named in its own `#[subenum(...)]` attribute. The sub-enum tagged `(default)` on a
/// variant uses that variant for its derived `Default`. Variants without a `#[subenum(...)]`
/// attribute belong to no sub-enum, so no weight-carrying [`Preset`] variant can select them.
///
/// Parsing through `EnumString` ignores ASCII case (e.g. `"q4_0"` parses as `Q4_0`), and
/// `VariantNames` exposes the variant names as a string list.
///
/// NOTE(review): the variant names and their declaration order appear to mirror ggml's
/// `ggml_type` list (with `F8_E4M3` appended at the end). If code elsewhere converts this enum
/// to or from a C enum by position, reordering or inserting variants would break it. Confirm
/// before editing.
#[non_exhaustive]
// Underscored names such as `Q4_0` and `IQ2_XXS` are kept deliberately to match the
// quantization type names, so the camel-case lint is silenced.
#[allow(non_camel_case_types)]
#[subenum(
Flux1Weight(derive(Default)),
Flux1MiniWeight(derive(Default)),
ChromaWeight(derive(Default)),
NitroSDRealismWeight(derive(Default)),
NitroSDVibrantWeight(derive(Default)),
DiffInstructStarWeight(derive(Default)),
ChromaRadianceWeight(derive(Default)),
SSD1BWeight(derive(Default)),
Flux2Weight(derive(Default)),
ZImageTurboWeight(derive(Default)),
QwenImageWeight(derive(Default)),
OvisImageWeight(derive(Default)),
TwinFlowZImageTurboExpWeight(derive(Default)),
Flux2Klein4BWeight(derive(Default)),
Flux2KleinBase4BWeight(derive(Default)),
Flux2Klein9BWeight(derive(Default)),
Flux2KleinBase9BWeight(derive(Default)),
AnimaWeight(derive(Default))
)]
#[derive(Debug, Clone, Copy, EnumString, VariantNames)]
#[strum(ascii_case_insensitive)]
pub enum WeightType {
#[subenum(Flux1MiniWeight)]
F32,
#[subenum(
NitroSDRealismWeight,
NitroSDVibrantWeight,
DiffInstructStarWeight,
SSD1BWeight
)]
F16,
#[subenum(
Flux1Weight,
ChromaWeight(default),
NitroSDRealismWeight,
NitroSDVibrantWeight,
DiffInstructStarWeight,
Flux2Weight,
ZImageTurboWeight,
QwenImageWeight,
OvisImageWeight(default),
TwinFlowZImageTurboExpWeight(default),
Flux2Klein4BWeight,
Flux2KleinBase4BWeight,
Flux2Klein9BWeight(default),
Flux2KleinBase9BWeight(default),
AnimaWeight
)]
Q4_0,
#[subenum(Flux2Weight, QwenImageWeight, AnimaWeight)]
Q4_1,
#[subenum(
NitroSDRealismWeight,
NitroSDVibrantWeight,
DiffInstructStarWeight,
Flux2Weight,
ZImageTurboWeight,
QwenImageWeight,
TwinFlowZImageTurboExpWeight,
AnimaWeight
)]
Q5_0,
#[subenum(Flux2Weight, QwenImageWeight, AnimaWeight)]
Q5_1,
#[subenum(
Flux1Weight,
Flux1MiniWeight(default),
ChromaWeight,
NitroSDRealismWeight(default),
NitroSDVibrantWeight(default),
DiffInstructStarWeight(default),
ChromaRadianceWeight(default),
Flux2Weight,
ZImageTurboWeight,
QwenImageWeight,
OvisImageWeight,
TwinFlowZImageTurboExpWeight,
Flux2Klein4BWeight(default),
Flux2KleinBase4BWeight(default),
Flux2Klein9BWeight,
AnimaWeight(default)
)]
Q8_0,
// Not a member of any per-model sub-enum.
Q8_1,
#[subenum(
Flux1Weight(default),
Flux1MiniWeight,
NitroSDRealismWeight,
NitroSDVibrantWeight,
DiffInstructStarWeight,
Flux2Weight(default),
ZImageTurboWeight,
QwenImageWeight(default)
)]
Q2_K,
#[subenum(
Flux1Weight,
Flux1MiniWeight,
NitroSDRealismWeight,
NitroSDVibrantWeight,
DiffInstructStarWeight,
ZImageTurboWeight,
Flux2Weight,
QwenImageWeight,
TwinFlowZImageTurboExpWeight,
AnimaWeight
)]
Q3_K,
#[subenum(
Flux1Weight,
ZImageTurboWeight(default),
Flux2Weight,
QwenImageWeight,
AnimaWeight
)]
Q4_K,
#[subenum(Flux1MiniWeight, Flux2Weight, QwenImageWeight, AnimaWeight)]
Q5_K,
#[subenum(
Flux1MiniWeight,
NitroSDRealismWeight,
NitroSDVibrantWeight,
DiffInstructStarWeight,
Flux2Weight,
ZImageTurboWeight,
QwenImageWeight,
TwinFlowZImageTurboExpWeight,
AnimaWeight
)]
Q6_K,
// `Q8_K` through `IQ1_M` below are not members of any per-model sub-enum; they exist only in
// the full `WeightType` (e.g. for string parsing and `VARIANTS`).
Q8_K,
IQ2_XXS,
IQ2_XS,
IQ3_XXS,
IQ1_S,
IQ4_NL,
IQ3_S,
IQ2_S,
IQ4_XS,
I8,
I16,
I32,
I64,
F64,
IQ1_M,
#[subenum(
Flux1MiniWeight,
ChromaWeight,
ChromaRadianceWeight,
Flux2Weight,
ZImageTurboWeight,
QwenImageWeight,
OvisImageWeight,
TwinFlowZImageTurboExpWeight,
Flux2Klein4BWeight,
Flux2KleinBase4BWeight,
Flux2Klein9BWeight,
Flux2KleinBase9BWeight,
AnimaWeight
)]
BF16,
// `TQ1_0`, `TQ2_0` and `MXFP4` are not members of any per-model sub-enum either.
TQ1_0,
TQ2_0,
MXFP4,
#[subenum(SSD1BWeight(default), QwenImageWeight)]
F8_E4M3,
}
/// A ready-made model setup that [`PresetBuilder`] expands into a `(Config, ModelConfig)` pair.
///
/// Variants with a payload let the caller choose the weight type. The payload is a per-model
/// subset of [`WeightType`] generated by `subenum`, so only the weight types listed for that
/// model can be constructed. Variants without a payload take no weight choice here.
///
/// `EnumDiscriminants` also generates a field-less discriminant enum, which strum names
/// `PresetDiscriminants` by default. It can be parsed from a string ignoring ASCII case and
/// lists its variant names via `VariantNames`. Variant declaration order therefore determines
/// the order of that name list.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, EnumDiscriminants)]
#[strum_discriminants(derive(EnumString, VariantNames), strum(ascii_case_insensitive))]
pub enum Preset {
StableDiffusion1_4,
StableDiffusion1_5,
StableDiffusion2_1,
StableDiffusion3Medium,
StableDiffusion3_5Medium,
StableDiffusion3_5Large,
StableDiffusion3_5LargeTurbo,
SDXLBase1_0,
SDTurbo,
SDXLTurbo1_0,
Flux1Dev(Flux1Weight),
Flux1Schnell(Flux1Weight),
Flux1Mini(Flux1MiniWeight),
JuggernautXL11,
Chroma(ChromaWeight),
NitroSDRealism(NitroSDRealismWeight),
NitroSDVibrant(NitroSDVibrantWeight),
DiffInstructStar(DiffInstructStarWeight),
ChromaRadiance(ChromaRadianceWeight),
SSD1B(SSD1BWeight),
Flux2Dev(Flux2Weight),
ZImageTurbo(ZImageTurboWeight),
QwenImage(QwenImageWeight),
OvisImage(OvisImageWeight),
DreamShaperXL2_1Turbo,
TwinFlowZImageTurboExp(TwinFlowZImageTurboExpWeight),
SDXS512DreamShaper,
Flux2Klein4B(Flux2Klein4BWeight),
Flux2KleinBase4B(Flux2KleinBase4BWeight),
Flux2Klein9B(Flux2Klein9BWeight),
Flux2KleinBase9B(Flux2KleinBase9BWeight),
SegmindVega,
Anima(AnimaWeight),
}
impl Preset {
    /// Resolves this preset into its `(ConfigBuilder, ModelConfigBuilder)` pair.
    ///
    /// Each variant delegates to the matching `preset_builder` function. For variants that
    /// carry a weight type, the selected weight is forwarded unchanged.
    ///
    /// # Errors
    /// Propagates any `ApiError` returned by the delegated preset function.
    fn try_configs_builder(self) -> Result<(ConfigBuilder, ModelConfigBuilder), ApiError> {
        match self {
            Self::StableDiffusion1_4 => stable_diffusion_1_4(),
            Self::StableDiffusion1_5 => stable_diffusion_1_5(),
            Self::StableDiffusion2_1 => stable_diffusion_2_1(),
            Self::StableDiffusion3Medium => stable_diffusion_3_medium(),
            Self::StableDiffusion3_5Medium => stable_diffusion_3_5_medium(),
            Self::StableDiffusion3_5Large => stable_diffusion_3_5_large(),
            Self::StableDiffusion3_5LargeTurbo => stable_diffusion_3_5_large_turbo(),
            Self::SDXLBase1_0 => sdxl_base_1_0(),
            Self::SDTurbo => sd_turbo(),
            Self::SDXLTurbo1_0 => sdxl_turbo_1_0(),
            Self::Flux1Dev(weight) => flux_1_dev(weight),
            Self::Flux1Schnell(weight) => flux_1_schnell(weight),
            Self::Flux1Mini(weight) => flux_1_mini(weight),
            Self::JuggernautXL11 => juggernaut_xl_11(),
            Self::Chroma(weight) => chroma(weight),
            Self::NitroSDRealism(weight) => nitro_sd_realism(weight),
            Self::NitroSDVibrant(weight) => nitro_sd_vibrant(weight),
            Self::DiffInstructStar(weight) => diff_instruct_star(weight),
            Self::ChromaRadiance(weight) => chroma_radiance(weight),
            Self::SSD1B(weight) => ssd_1b(weight),
            Self::Flux2Dev(weight) => flux_2_dev(weight),
            Self::ZImageTurbo(weight) => z_image_turbo(weight),
            Self::QwenImage(weight) => qwen_image(weight),
            Self::OvisImage(weight) => ovis_image(weight),
            Self::DreamShaperXL2_1Turbo => dream_shaper_xl_2_1_turbo(),
            Self::TwinFlowZImageTurboExp(weight) => twinflow_z_image_turbo(weight),
            Self::SDXS512DreamShaper => sdxs512_dream_shaper(),
            Self::Flux2Klein4B(weight) => flux_2_klein_4b(weight),
            Self::Flux2KleinBase4B(weight) => flux_2_klein_base_4b(weight),
            Self::Flux2Klein9B(weight) => flux_2_klein_9b(weight),
            Self::Flux2KleinBase9B(weight) => flux_2_klein_base_9b(weight),
            Self::SegmindVega => segmind_vega(),
            Self::Anima(weight) => anima(weight),
        }
    }
}
/// The pair of builders a preset produces: a `ConfigBuilder` and a `ModelConfigBuilder`.
pub type ConfigsBuilder = (ConfigBuilder, ModelConfigBuilder);
/// The built counterpart of [`ConfigsBuilder`], returned by `PresetBuilder::build`.
pub type Configs = (Config, ModelConfig);
// A user-supplied hook registered through `PresetBuilder::with_modifier`. It takes the
// preset's builders by value and returns them (possibly modified), or an `ApiError` that
// aborts the build.
type ModifierFunction = dyn FnOnce(ConfigsBuilder) -> Result<ConfigsBuilder, ApiError>;
/// Everything needed to turn a [`Preset`] into concrete configs: the prompt, the preset
/// itself, and any user modifiers.
///
/// Construct it through the generated `PresetBuilder`. That builder's public `build` method
/// yields the final `(Config, ModelConfig)` pair.
#[derive(Builder)]
#[builder(
    name = "PresetBuilder",
    pattern = "owned",
    setter(into),
    build_fn(name = "internal_build", private, error = "ConfigBuilderError")
)]
pub struct PresetConfig {
    /// Prompt text. It is written into the `ConfigBuilder` after all modifiers have run.
    prompt: String,
    /// The preset whose builders are produced.
    preset: Preset,
    /// Hooks applied in registration order. Filled only via `PresetBuilder::with_modifier`;
    /// empty when none are registered.
    #[builder(private, default)]
    modifiers: Vec<Box<ModifierFunction>>,
}
impl PresetBuilder {
    /// Registers a hook that may adjust the preset's builders before they are built.
    ///
    /// Hooks run once each, in registration order. They run after the preset's own builders
    /// have been created and before the prompt is applied, so a hook cannot override the
    /// prompt. If a hook returns `Err`, [`PresetBuilder::build`] is aborted and the error is
    /// reported as `ConfigBuilderError::ValidationError`.
    #[must_use = "the builder is consumed by value; use the returned builder"]
    pub fn with_modifier<F>(mut self, f: F) -> Self
    where
        F: FnOnce(ConfigsBuilder) -> Result<ConfigsBuilder, ApiError> + 'static,
    {
        // `modifiers` stays `None` until the first hook is added. Create the list lazily,
        // without a separate `is_none` check followed by `unwrap`.
        self.modifiers.get_or_insert_with(Vec::new).push(Box::new(f));
        self
    }

    /// Validates the builder, resolves the selected preset and builds both configs.
    ///
    /// # Errors
    /// - Errors from the derived builder itself, e.g. when `prompt` or `preset` was never
    ///   set (derive_builder reports these as uninitialized-field errors).
    /// - `ConfigBuilderError::ValidationError` carrying the message of any `ApiError`
    ///   raised while resolving the preset or running a modifier.
    /// - Any error from building the `Config` or the `ModelConfig`.
    pub fn build(self) -> Result<Configs, ConfigBuilderError> {
        let preset = self.internal_build()?;
        let (config_builder, model_config_builder): ConfigsBuilder = preset
            .try_into()
            .map_err(|err: ApiError| ConfigBuilderError::ValidationError(err.to_string()))?;
        let config = config_builder.build()?;
        let model_config = model_config_builder.build()?;
        Ok((config, model_config))
    }
}
impl TryFrom<PresetConfig> for ConfigsBuilder {
type Error = ApiError;
fn try_from(value: PresetConfig) -> Result<Self, Self::Error> {
let mut configs_builder = value.preset.try_configs_builder()?;
for m in value.modifiers {
configs_builder = m(configs_builder)?;
}
configs_builder.0.prompt(value.prompt);
Ok(configs_builder)
}
}
#[cfg(test)]
mod tests {
    use crate::{
        api::gen_img,
        preset::{
            ChromaRadianceWeight, ChromaWeight, DiffInstructStarWeight, Flux1MiniWeight,
            Flux1Weight, Flux2Klein4BWeight, Flux2Klein9BWeight, Flux2KleinBase4BWeight,
            Flux2KleinBase9BWeight, Flux2Weight, NitroSDRealismWeight, NitroSDVibrantWeight,
            OvisImageWeight, QwenImageWeight, SSD1BWeight, TwinFlowZImageTurboExpWeight,
            ZImageTurboWeight,
        },
        util::set_hf_token,
    };

    use super::{AnimaWeight, Preset, PresetBuilder};

    static PROMPT: &str = "a lovely dinosaur made by crochet";

    /// Passes the token read from `token.txt` to `set_hf_token` before a test that needs it.
    ///
    /// NOTE: `include_str!` is evaluated at compile time. `token.txt`, located one directory
    /// above this source file, must therefore exist for the test build to compile at all,
    /// even though every test here is `#[ignore]`d.
    fn use_hf_token() {
        set_hf_token(include_str!("../token.txt"));
    }

    /// Builds the configs for `preset` with the shared prompt and generates one image,
    /// panicking on any error.
    fn run(preset: Preset) {
        let (config, mut model_config) = PresetBuilder::default()
            .preset(preset)
            .prompt(PROMPT)
            .build()
            .unwrap();
        gen_img(&config, &mut model_config).unwrap();
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_1_4() {
        run(Preset::StableDiffusion1_4);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_1_5() {
        run(Preset::StableDiffusion1_5);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_2_1() {
        run(Preset::StableDiffusion2_1);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_medium_fp16() {
        use_hf_token();
        run(Preset::StableDiffusion3Medium);
    }

    #[ignore]
    #[test]
    fn test_sdxl_base_1_0() {
        run(Preset::SDXLBase1_0);
    }

    #[ignore]
    #[test]
    fn test_flux_1_dev() {
        use_hf_token();
        run(Preset::Flux1Dev(Flux1Weight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_flux_1_schnell() {
        use_hf_token();
        run(Preset::Flux1Schnell(Flux1Weight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_sd_turbo() {
        run(Preset::SDTurbo);
    }

    #[ignore]
    #[test]
    fn test_sdxl_turbo_1_0_fp16() {
        run(Preset::SDXLTurbo1_0);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_5_medium_fp16() {
        use_hf_token();
        run(Preset::StableDiffusion3_5Medium);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_5_large_fp16() {
        use_hf_token();
        run(Preset::StableDiffusion3_5Large);
    }

    #[ignore]
    #[test]
    fn test_stable_diffusion_3_5_large_turbo_fp16() {
        use_hf_token();
        run(Preset::StableDiffusion3_5LargeTurbo);
    }

    #[ignore]
    #[test]
    fn test_juggernaut_xl_11() {
        use_hf_token();
        run(Preset::JuggernautXL11);
    }

    #[ignore]
    #[test]
    fn test_flux_1_mini() {
        use_hf_token();
        run(Preset::Flux1Mini(Flux1MiniWeight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_chroma() {
        use_hf_token();
        run(Preset::Chroma(ChromaWeight::Q4_0));
    }

    #[ignore]
    #[test]
    fn test_nitro_sd_realism() {
        run(Preset::NitroSDRealism(NitroSDRealismWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_nitro_sd_vibrant() {
        run(Preset::NitroSDVibrant(NitroSDVibrantWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_diff_instruct_star() {
        run(Preset::DiffInstructStar(DiffInstructStarWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_chroma_radiance() {
        run(Preset::ChromaRadiance(ChromaRadianceWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_ssd_1b() {
        run(Preset::SSD1B(SSD1BWeight::F8_E4M3));
    }

    #[ignore]
    #[test]
    fn test_flux_2_dev() {
        use_hf_token();
        run(Preset::Flux2Dev(Flux2Weight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_z_image_turbo() {
        use_hf_token();
        run(Preset::ZImageTurbo(ZImageTurboWeight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_qwen_image() {
        run(Preset::QwenImage(QwenImageWeight::Q2_K));
    }

    #[ignore]
    #[test]
    fn test_ovis_image() {
        use_hf_token();
        run(Preset::OvisImage(OvisImageWeight::Q4_0));
    }

    #[ignore]
    #[test]
    fn test_dreamshaper_xl_2_1_turbo() {
        run(Preset::DreamShaperXL2_1Turbo);
    }

    #[ignore]
    #[test]
    fn test_twinflow_z_image_turbo_exp() {
        use_hf_token();
        run(Preset::TwinFlowZImageTurboExp(
            TwinFlowZImageTurboExpWeight::Q3_K,
        ));
    }

    #[ignore]
    #[test]
    fn test_sdxs512_dream_shaper() {
        run(Preset::SDXS512DreamShaper);
    }

    #[ignore]
    #[test]
    fn test_flux_2_klein_4b() {
        use_hf_token();
        run(Preset::Flux2Klein4B(Flux2Klein4BWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_flux_2_klein_base_4b() {
        use_hf_token();
        run(Preset::Flux2KleinBase4B(Flux2KleinBase4BWeight::Q8_0));
    }

    #[ignore]
    #[test]
    fn test_flux_2_klein_9b() {
        use_hf_token();
        run(Preset::Flux2Klein9B(Flux2Klein9BWeight::Q4_0));
    }

    #[ignore]
    #[test]
    fn test_flux_2_klein_base_9b() {
        use_hf_token();
        run(Preset::Flux2KleinBase9B(Flux2KleinBase9BWeight::Q4_0));
    }

    #[ignore]
    #[test]
    fn test_segmind_vega() {
        run(Preset::SegmindVega);
    }

    #[ignore]
    #[test]
    fn test_anima() {
        run(Preset::Anima(AnimaWeight::Q8_0));
    }
}