use std::cmp::max;
use std::collections::HashMap;
use std::ffi::CString;
use std::ffi::c_char;
use std::ffi::c_void;
use std::fmt::Display;
use std::path::Path;
use std::path::PathBuf;
use std::ptr::null;
use std::ptr::null_mut;
use std::slice;
use std::sync::mpsc::Sender;
use chrono::Local;
use derive_builder::Builder;
use diffusion_rs_sys::free_upscaler_ctx;
use diffusion_rs_sys::generate_image;
use diffusion_rs_sys::new_upscaler_ctx;
use diffusion_rs_sys::sd_cache_mode_t;
use diffusion_rs_sys::sd_cache_params_t;
use diffusion_rs_sys::sd_ctx_params_t;
use diffusion_rs_sys::sd_embedding_t;
use diffusion_rs_sys::sd_get_default_sample_method;
use diffusion_rs_sys::sd_get_default_scheduler;
use diffusion_rs_sys::sd_guidance_params_t;
use diffusion_rs_sys::sd_image_t;
use diffusion_rs_sys::sd_img_gen_params_t;
use diffusion_rs_sys::sd_img_gen_params_to_str;
use diffusion_rs_sys::sd_lora_t;
use diffusion_rs_sys::sd_pm_params_t;
use diffusion_rs_sys::sd_sample_params_t;
use diffusion_rs_sys::sd_set_preview_callback;
use diffusion_rs_sys::sd_set_progress_callback;
use diffusion_rs_sys::sd_slg_params_t;
use diffusion_rs_sys::sd_tiling_params_t;
use diffusion_rs_sys::upscaler_ctx_t;
use image::ImageBuffer;
use image::ImageError;
use image::RgbImage;
use libc::free;
use little_exif::exif_tag::ExifTag;
use little_exif::metadata::Metadata;
use thiserror::Error;
use walkdir::DirEntry;
use walkdir::WalkDir;
use diffusion_rs_sys::free_sd_ctx;
use diffusion_rs_sys::new_sd_ctx;
use diffusion_rs_sys::sd_ctx_t;
pub use diffusion_rs_sys::rng_type_t as RngFunction;
pub use diffusion_rs_sys::sample_method_t as SampleMethod;
pub use diffusion_rs_sys::scheduler_t as Scheduler;
pub use diffusion_rs_sys::prediction_t as Prediction;
pub use diffusion_rs_sys::sd_type_t as WeightType;
pub use diffusion_rs_sys::preview_t as PreviewType;
pub use diffusion_rs_sys::lora_apply_mode_t as LoraModeType;
/// File extensions accepted when scanning the embeddings and LoRA directories
/// (see `ModelConfigBuilder::filter_valid_extensions`).
static VALID_EXT: [&str; 3] = ["gguf", "safetensors", "pt"];
/// Progress report sent through the channel passed to [`gen_img_with_progress`].
///
/// The values are forwarded verbatim from the stable-diffusion.cpp progress
/// callback.
#[derive(Debug)]
pub struct Progress {
    step: i32,
    steps: i32,
    time: f32,
}
impl Progress {
    /// Current step as reported by the library.
    pub fn step(&self) -> i32 {
        self.step
    }

    /// Total number of steps as reported by the library.
    pub fn steps(&self) -> i32 {
        self.steps
    }

    /// Timing value reported by the library for this step (presumably seconds;
    /// confirm against stable-diffusion.cpp).
    pub fn time(&self) -> f32 {
        self.time
    }
}
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum DiffusionError {
#[error("The underling stablediffusion.cpp function returned NULL")]
Forward,
#[error(transparent)]
StoreImages(#[from] ImageError),
#[error(transparent)]
Io(#[from] std::io::Error),
#[error("The underling upscaler model returned a NULL image")]
Upscaler,
}
/// CLIP skip setting, passed to stable-diffusion.cpp as its `i32`
/// discriminant (`clip_skip` in the image generation params).
#[repr(i32)]
#[non_exhaustive]
#[derive(Debug, Default, Copy, Clone, Hash, PartialEq, Eq)]
pub enum ClipSkip {
    /// 0: no explicit value; presumably lets the library pick its
    /// model-specific default (confirm against stable-diffusion.cpp).
    #[default]
    Unspecified = 0,
    /// 1: do not skip any CLIP layer.
    None = 1,
    /// 2: skip the last CLIP layer.
    OneLayer = 2,
}
/// Embeddings directory, the `(name, path)` C strings discovered in it, and the
/// raw `sd_embedding_t` table whose pointers refer to those strings.
type EmbeddingsStorage = (PathBuf, Vec<(CLibString, CLibPath)>, Vec<sd_embedding_t>);
/// LoRA weight file paths paired with the spec that selected them.
type LoraStorage = Vec<(CLibPath, LoraSpec)>;
/// Selects one LoRA from the directory passed to
/// `ModelConfigBuilder::lora_models`.
#[derive(Default, Debug, Clone)]
pub struct LoraSpec {
    /// File stem (file name without extension) of the LoRA weights; specs whose
    /// stem is not found in the directory are ignored.
    pub file_name: String,
    /// Forwarded as `sd_lora_t::is_high_noise`.
    pub is_high_noise: bool,
    /// Forwarded as `sd_lora_t::multiplier`. Note that `Default` yields `0.0`.
    pub multiplier: f32,
}
/// Parameters for the UCache mode (see `ConfigBuilder::ucache_caching`).
///
/// Mapping to `sd_cache_params_t`: `threshold` → `reuse_threshold`,
/// `start` → `start_percent`, `end` → `end_percent`,
/// `decay` → `error_decay_rate`, `relative` → `use_relative_threshold`,
/// `reset` → `reset_error_on_compute`.
#[derive(Builder, Debug, Clone)]
pub struct UCacheParams {
    #[builder(default = "1.0")]
    threshold: f32,
    #[builder(default = "0.15")]
    start: f32,
    #[builder(default = "0.95")]
    end: f32,
    #[builder(default = "1.0")]
    decay: f32,
    #[builder(default = "true")]
    relative: bool,
    #[builder(default = "true")]
    reset: bool,
}
/// Parameters for the EasyCache mode (see `ConfigBuilder::easy_cache_caching`).
///
/// Mapping to `sd_cache_params_t`: `threshold` → `reuse_threshold`,
/// `start` → `start_percent`, `end` → `end_percent`.
#[derive(Builder, Debug, Clone)]
pub struct EasyCacheParams {
    #[builder(default = "0.2")]
    threshold: f32,
    #[builder(default = "0.15")]
    start: f32,
    #[builder(default = "0.95")]
    end: f32,
}
/// Parameters for the DBCache and cache-dit modes (see
/// `ConfigBuilder::db_cache_caching` and `ConfigBuilder::cache_dit_caching`).
///
/// Mapping to `sd_cache_params_t`: `fn_blocks` → `Fn_compute_blocks`,
/// `bn_blocks` → `Bn_compute_blocks`, `threshold` → `residual_diff_threshold`,
/// `warmup` → `max_warmup_steps`, `scm_policy_dynamic` → `scm_policy_dynamic`.
/// `scm_mask` is rendered to the comma-separated step mask string.
#[derive(Builder, Debug, Clone)]
pub struct DbCacheParams {
    #[builder(default = "8")]
    fn_blocks: i32,
    #[builder(default = "0")]
    bn_blocks: i32,
    #[builder(default = "0.08")]
    threshold: f32,
    #[builder(default = "8")]
    warmup: i32,
    /// Step Computation Mask preset. Like every other field it now has a
    /// default: the enum's own `Default` (`ScmPreset::Medium`). Previously
    /// `build()` failed when it was not set explicitly.
    #[builder(default = "ScmPreset::default()")]
    scm_mask: ScmPreset,
    #[builder(default = "ScmPolicy::default()")]
    scm_policy_dynamic: ScmPolicy,
}
/// Step Computation Mask policy; `Dynamic` sets
/// `sd_cache_params_t::scm_policy_dynamic` to `true`, `Static` to `false`.
#[derive(Debug, Default, Clone)]
pub enum ScmPolicy {
    Static,
    #[default]
    Dynamic,
}
#[derive(Debug, Default, Clone)]
pub enum ScmPreset {
Slow,
#[default]
Medium,
Fast,
Ultra,
Custom(String),
}
impl ScmPreset {
fn to_vec_string(&self, steps: i32) -> String {
match self {
ScmPreset::Slow => ScmPresetBins {
compute_bins: vec![8, 3, 3, 2, 1, 1],
cache_bins: vec![1, 2, 2, 2, 3],
steps,
}
.to_string(),
ScmPreset::Medium => ScmPresetBins {
compute_bins: vec![6, 2, 2, 2, 2, 1],
cache_bins: vec![1, 3, 3, 3, 3],
steps,
}
.to_string(),
ScmPreset::Fast => ScmPresetBins {
compute_bins: vec![6, 1, 1, 1, 1, 1],
cache_bins: vec![1, 3, 4, 5, 4],
steps,
}
.to_string(),
ScmPreset::Ultra => ScmPresetBins {
compute_bins: vec![4, 1, 1, 1, 1],
cache_bins: vec![2, 5, 6, 7],
steps,
}
.to_string(),
ScmPreset::Custom(s) => s.clone(),
}
}
}
/// Compute/cache run lengths from which a Step Computation Mask is generated.
///
/// The mask alternates runs of `1` (step is computed), taken from
/// `compute_bins`, with runs of `0` (step may reuse cached results), taken
/// from `cache_bins`. It starts with a compute run. The bins are designed for
/// `REFERENCE_STEPS` steps and are rescaled for other step counts.
#[derive(Debug, Clone)]
struct ScmPresetBins {
    compute_bins: Vec<i32>,
    cache_bins: Vec<i32>,
    steps: i32,
}
impl ScmPresetBins {
    /// Step count the preset bins were tuned for.
    const REFERENCE_STEPS: i32 = 28;

    /// Returns the bins rescaled to `steps`. Returns an unchanged copy when
    /// `steps` already equals the reference count or is non-positive.
    fn maybe_scale(&self) -> ScmPresetBins {
        if self.steps == Self::REFERENCE_STEPS || self.steps <= 0 {
            return self.clone();
        }
        self.scale()
    }

    /// Scales every bin by `steps / REFERENCE_STEPS`, rounding to the nearest
    /// integer and never going below 1.
    fn scale(&self) -> ScmPresetBins {
        let factor = self.steps as f32 / Self::REFERENCE_STEPS as f32;
        // Bins are positive, so adding 0.5 before truncating rounds to nearest.
        let scale_bin = |bin: &i32| max(1, (*bin as f32 * factor + 0.5) as i32);
        ScmPresetBins {
            compute_bins: self.compute_bins.iter().map(scale_bin).collect(),
            cache_bins: self.cache_bins.iter().map(scale_bin).collect(),
            steps: self.steps,
        }
    }

    /// Expands the bins into a `1`/`0` mask with at most `steps` entries.
    ///
    /// The mask is shorter if the bins run out first, and empty when
    /// `steps <= 0`. The final entry is always forced to `1`, so the last step
    /// is computed.
    fn generate_vec_mask(&self) -> Vec<i32> {
        // Non-positive step counts produce an empty mask (no `as usize` wrap).
        let total = usize::try_from(self.steps).unwrap_or(0);
        let mut mask = Vec::with_capacity(total);
        // Appends up to `count` copies of `value` without exceeding `total`.
        let push_run = |buf: &mut Vec<i32>, value: i32, count: i32| {
            let room = total - buf.len();
            let run = usize::try_from(count).unwrap_or(0).min(room);
            buf.resize(buf.len() + run, value);
        };
        let mut compute_idx = 0;
        let mut cache_idx = 0;
        while mask.len() < total
            && (compute_idx < self.compute_bins.len() || cache_idx < self.cache_bins.len())
        {
            if let Some(&count) = self.compute_bins.get(compute_idx) {
                push_run(&mut mask, 1, count);
                compute_idx += 1;
            }
            // Each run list has its own cursor; cache runs must use `cache_idx`.
            if let Some(&count) = self.cache_bins.get(cache_idx) {
                push_run(&mut mask, 0, count);
                cache_idx += 1;
            }
        }
        if let Some(last) = mask.last_mut() {
            *last = 1;
        }
        mask
    }
}
impl Display for ScmPresetBins {
    /// Formats the (possibly rescaled) mask as comma-separated `0`/`1` values.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mask: String = self
            .maybe_scale()
            .generate_vec_mask()
            .iter()
            .map(|x| x.to_string())
            .collect::<Vec<_>>()
            .join(",");
        write!(f, "{mask}")
    }
}
/// Model-level configuration: weight files, backend options and the lazily
/// created stable-diffusion.cpp contexts.
///
/// Construct with [`ModelConfigBuilder`]; either `model` or `diffusion_model`
/// must be set. The native contexts (`diffusion_ctx`, `upscaler_ctx`) are
/// created on first use and freed when the config is dropped.
#[derive(Builder, Debug)]
#[builder(
    setter(into, strip_option),
    build_fn(error = "ConfigBuilderError", validate = "Self::validate")
)]
pub struct ModelConfig {
    #[builder(default = "num_cpus::get_physical() as i32", setter(custom))]
    n_threads: i32,
    #[builder(default = "false")]
    enable_mmap: bool,
    #[builder(default = "false")]
    offload_params_to_cpu: bool,
    #[builder(default = "Default::default()")]
    upscale_model: Option<CLibPath>,
    #[builder(default = "1")]
    upscale_repeats: i32,
    #[builder(default = "128")]
    upscale_tile_size: i32,
    #[builder(default = "Default::default()")]
    model: CLibPath,
    #[builder(default = "Default::default()")]
    diffusion_model: CLibPath,
    #[builder(default = "Default::default()")]
    llm: CLibPath,
    #[builder(default = "Default::default()")]
    llm_vision: CLibPath,
    #[builder(default = "Default::default()")]
    clip_l: CLibPath,
    #[builder(default = "Default::default()")]
    clip_g: CLibPath,
    #[builder(default = "Default::default()")]
    clip_vision: CLibPath,
    #[builder(default = "Default::default()")]
    t5xxl: CLibPath,
    #[builder(default = "Default::default()")]
    vae: CLibPath,
    #[builder(default = "Default::default()")]
    taesd: CLibPath,
    #[builder(default = "Default::default()")]
    control_net: CLibPath,
    #[builder(default = "Default::default()", setter(custom))]
    embeddings: EmbeddingsStorage,
    #[builder(default = "Default::default()")]
    photo_maker: CLibPath,
    #[builder(default = "Default::default()")]
    pm_id_embed_path: CLibPath,
    #[builder(default = "WeightType::SD_TYPE_COUNT")]
    weight_type: WeightType,
    #[builder(default = "Default::default()", setter(custom))]
    lora_models: LoraStorage,
    #[builder(default = "Default::default()")]
    high_noise_diffusion_model: CLibPath,
    #[builder(default = "false")]
    vae_tiling: bool,
    #[builder(default = "(32,32)")]
    vae_tile_size: (i32, i32),
    #[builder(default = "(0.,0.)")]
    vae_relative_tile_size: (f32, f32),
    #[builder(default = "0.5")]
    vae_tile_overlap: f32,
    #[builder(default = "RngFunction::CUDA_RNG")]
    rng: RngFunction,
    #[builder(default = "RngFunction::RNG_TYPE_COUNT")]
    sampler_rng_type: RngFunction,
    #[builder(default = "Scheduler::SCHEDULER_COUNT")]
    scheduler: Scheduler,
    #[builder(default = "Default::default()")]
    sigmas: Vec<f32>,
    #[builder(default = "Prediction::PREDICTION_COUNT")]
    prediction: Prediction,
    #[builder(default = "false")]
    vae_on_cpu: bool,
    #[builder(default = "false")]
    clip_on_cpu: bool,
    #[builder(default = "false")]
    control_net_cpu: bool,
    #[builder(default = "false")]
    diffusion_flash_attention: bool,
    #[builder(default = "false")]
    flash_attention: bool,
    #[builder(default = "false")]
    chroma_disable_dit_mask: bool,
    #[builder(default = "false")]
    chroma_enable_t5_mask: bool,
    #[builder(default = "1")]
    chroma_t5_mask_pad: i32,
    #[builder(default = "false")]
    use_qwen_image_zero_cond_true: bool,
    #[builder(default = "false")]
    diffusion_conv_direct: bool,
    #[builder(default = "false")]
    vae_conv_direct: bool,
    #[builder(default = "false")]
    force_sdxl_vae_conv_scale: bool,
    #[builder(default = "f32::INFINITY")]
    flow_shift: f32,
    #[builder(default = "0")]
    timestep_shift: i32,
    #[builder(default = "false")]
    taesd_preview_only: bool,
    #[builder(default = "LoraModeType::LORA_APPLY_AUTO")]
    lora_apply_mode: LoraModeType,
    #[builder(default = "false")]
    circular: bool,
    #[builder(default = "false")]
    circular_x: bool,
    #[builder(default = "false")]
    circular_y: bool,
    #[builder(default = "None", private)]
    upscaler_ctx: Option<*mut upscaler_ctx_t>,
    #[builder(default = "None", private)]
    diffusion_ctx: Option<(*mut sd_ctx_t, sd_ctx_params_t)>,
}

/// Copies the configuration but not the native contexts.
///
/// `Drop` frees `diffusion_ctx` / `upscaler_ctx`. Copying those raw pointers,
/// as a derived `Clone` would, lets the original and the clone free the same
/// context twice. The clone therefore starts without contexts and creates its
/// own on first use. The raw embedding table is rebuilt so its pointers refer
/// to the clone's own strings instead of the original's.
impl Clone for ModelConfig {
    fn clone(&self) -> Self {
        let (embeddings_dir, embedding_entries, _) = &self.embeddings;
        let embedding_entries = embedding_entries.clone();
        // Pointers target the heap buffers of `embedding_entries`, which stay
        // put when the Vec is moved into the new struct below.
        let embedding_ptrs = embedding_entries
            .iter()
            .map(|(name, path)| sd_embedding_t {
                name: name.as_ptr(),
                path: path.as_ptr(),
            })
            .collect();
        Self {
            n_threads: self.n_threads,
            enable_mmap: self.enable_mmap,
            offload_params_to_cpu: self.offload_params_to_cpu,
            upscale_model: self.upscale_model.clone(),
            upscale_repeats: self.upscale_repeats,
            upscale_tile_size: self.upscale_tile_size,
            model: self.model.clone(),
            diffusion_model: self.diffusion_model.clone(),
            llm: self.llm.clone(),
            llm_vision: self.llm_vision.clone(),
            clip_l: self.clip_l.clone(),
            clip_g: self.clip_g.clone(),
            clip_vision: self.clip_vision.clone(),
            t5xxl: self.t5xxl.clone(),
            vae: self.vae.clone(),
            taesd: self.taesd.clone(),
            control_net: self.control_net.clone(),
            embeddings: (embeddings_dir.clone(), embedding_entries, embedding_ptrs),
            photo_maker: self.photo_maker.clone(),
            pm_id_embed_path: self.pm_id_embed_path.clone(),
            weight_type: self.weight_type,
            lora_models: self.lora_models.clone(),
            high_noise_diffusion_model: self.high_noise_diffusion_model.clone(),
            vae_tiling: self.vae_tiling,
            vae_tile_size: self.vae_tile_size,
            vae_relative_tile_size: self.vae_relative_tile_size,
            vae_tile_overlap: self.vae_tile_overlap,
            rng: self.rng,
            sampler_rng_type: self.sampler_rng_type,
            scheduler: self.scheduler,
            sigmas: self.sigmas.clone(),
            prediction: self.prediction,
            vae_on_cpu: self.vae_on_cpu,
            clip_on_cpu: self.clip_on_cpu,
            control_net_cpu: self.control_net_cpu,
            diffusion_flash_attention: self.diffusion_flash_attention,
            flash_attention: self.flash_attention,
            chroma_disable_dit_mask: self.chroma_disable_dit_mask,
            chroma_enable_t5_mask: self.chroma_enable_t5_mask,
            chroma_t5_mask_pad: self.chroma_t5_mask_pad,
            use_qwen_image_zero_cond_true: self.use_qwen_image_zero_cond_true,
            diffusion_conv_direct: self.diffusion_conv_direct,
            vae_conv_direct: self.vae_conv_direct,
            force_sdxl_vae_conv_scale: self.force_sdxl_vae_conv_scale,
            flow_shift: self.flow_shift,
            timestep_shift: self.timestep_shift,
            taesd_preview_only: self.taesd_preview_only,
            lora_apply_mode: self.lora_apply_mode,
            circular: self.circular,
            circular_x: self.circular_x,
            circular_y: self.circular_y,
            // Native contexts are owned by exactly one config; never share them.
            upscaler_ctx: None,
            diffusion_ctx: None,
        }
    }
}
impl ModelConfigBuilder {
fn validate(&self) -> Result<(), ConfigBuilderError> {
self.validate_model()
}
fn validate_model(&self) -> Result<(), ConfigBuilderError> {
self.model
.as_ref()
.or(self.diffusion_model.as_ref())
.map(|_| ())
.ok_or(ConfigBuilderError::UninitializedField(
"Model OR DiffusionModel must be valorized",
))
}
fn filter_valid_extensions(path: &Path) -> impl Iterator<Item = DirEntry> {
WalkDir::new(path)
.into_iter()
.filter_map(|entry| entry.ok())
.filter(|entry| {
entry
.path()
.extension()
.and_then(|ext| ext.to_str())
.map(|ext_str| VALID_EXT.contains(&ext_str))
.unwrap_or(false)
})
}
fn build_single_lora_storage(
spec: &LoraSpec,
valid_loras: &HashMap<String, PathBuf>,
) -> (CLibPath, LoraSpec) {
let path = valid_loras.get(&spec.file_name).unwrap().as_path();
let c_path = CLibPath::from(path);
(c_path, spec.clone())
}
pub fn embeddings(&mut self, embeddings_dir: &Path) -> &mut Self {
let data: Vec<(CLibString, CLibPath)> = Self::filter_valid_extensions(embeddings_dir)
.map(|entry| {
let file_stem = entry
.path()
.file_stem()
.and_then(|stem| stem.to_str())
.unwrap_or_default()
.to_owned();
(CLibString::from(file_stem), CLibPath::from(entry.path()))
})
.collect();
let data_pointer = data
.iter()
.map(|(name, path)| sd_embedding_t {
name: name.as_ptr(),
path: path.as_ptr(),
})
.collect();
self.embeddings = Some((embeddings_dir.to_path_buf(), data, data_pointer));
self
}
pub fn lora_models(&mut self, lora_model_dir: &Path, specs: Vec<LoraSpec>) -> &mut Self {
let valid_loras: HashMap<String, PathBuf> = Self::filter_valid_extensions(lora_model_dir)
.map(|entry| {
let path = entry.path();
(
path.file_stem()
.and_then(|stem| stem.to_str())
.unwrap_or_default()
.to_owned(),
path.to_path_buf(),
)
})
.collect();
let valid_lora_names: Vec<&String> = valid_loras.keys().collect();
let standard = specs
.iter()
.filter(|s| valid_lora_names.contains(&&s.file_name) && !s.is_high_noise)
.map(|s| Self::build_single_lora_storage(s, &valid_loras));
let high_noise = specs
.iter()
.filter(|s| valid_lora_names.contains(&&s.file_name) && s.is_high_noise)
.map(|s| Self::build_single_lora_storage(s, &valid_loras));
self.lora_models_internal(standard.chain(high_noise).collect())
}
fn lora_models_internal(&mut self, lora_storage: LoraStorage) -> &mut Self {
self.lora_models = Some(lora_storage);
self
}
pub fn n_threads(&mut self, value: i32) -> &mut Self {
self.n_threads = if value > 0 {
Some(value)
} else {
Some(num_cpus::get_physical() as i32)
};
self
}
}
impl ModelConfig {
unsafe fn upscaler_ctx(&mut self) -> Option<*mut upscaler_ctx_t> {
unsafe {
if self.upscale_model.is_none() || self.upscale_repeats == 0 {
None
} else {
if self.upscaler_ctx.is_none() {
let upscaler = new_upscaler_ctx(
self.upscale_model.as_ref().unwrap().as_ptr(),
self.offload_params_to_cpu,
self.diffusion_conv_direct,
self.n_threads,
self.upscale_tile_size,
);
self.upscaler_ctx = Some(upscaler);
}
self.upscaler_ctx
}
}
}
unsafe fn diffusion_ctx(&mut self, vae_decode_only: bool) -> *mut sd_ctx_t {
unsafe {
if let Some((sd_ctx, sd_ctx_params)) = self.diffusion_ctx.as_ref()
&& sd_ctx_params.vae_decode_only != vae_decode_only
{
sd_set_progress_callback(None, null_mut());
free_sd_ctx(*sd_ctx);
self.diffusion_ctx = None;
}
if self.diffusion_ctx.is_none() {
let sd_ctx_params = sd_ctx_params_t {
model_path: self.model.as_ptr(),
llm_path: self.llm.as_ptr(),
llm_vision_path: self.llm_vision.as_ptr(),
clip_l_path: self.clip_l.as_ptr(),
clip_g_path: self.clip_g.as_ptr(),
clip_vision_path: self.clip_vision.as_ptr(),
high_noise_diffusion_model_path: self.high_noise_diffusion_model.as_ptr(),
t5xxl_path: self.t5xxl.as_ptr(),
diffusion_model_path: self.diffusion_model.as_ptr(),
vae_path: self.vae.as_ptr(),
taesd_path: self.taesd.as_ptr(),
control_net_path: self.control_net.as_ptr(),
embeddings: self.embeddings.2.as_ptr(),
embedding_count: self.embeddings.1.len() as u32,
photo_maker_path: self.photo_maker.as_ptr(),
vae_decode_only,
free_params_immediately: false,
n_threads: self.n_threads,
wtype: self.weight_type,
rng_type: self.rng,
keep_clip_on_cpu: self.clip_on_cpu,
keep_control_net_on_cpu: self.control_net_cpu,
keep_vae_on_cpu: self.vae_on_cpu,
diffusion_flash_attn: self.diffusion_flash_attention,
flash_attn: self.flash_attention,
diffusion_conv_direct: self.diffusion_conv_direct,
chroma_use_dit_mask: !self.chroma_disable_dit_mask,
chroma_use_t5_mask: self.chroma_enable_t5_mask,
chroma_t5_mask_pad: self.chroma_t5_mask_pad,
vae_conv_direct: self.vae_conv_direct,
offload_params_to_cpu: self.offload_params_to_cpu,
prediction: self.prediction,
force_sdxl_vae_conv_scale: self.force_sdxl_vae_conv_scale,
tae_preview_only: self.taesd_preview_only,
lora_apply_mode: self.lora_apply_mode,
tensor_type_rules: null_mut(),
sampler_rng_type: self.sampler_rng_type,
circular_x: self.circular || self.circular_x,
circular_y: self.circular || self.circular_y,
qwen_image_zero_cond_t: self.use_qwen_image_zero_cond_true,
enable_mmap: self.enable_mmap,
};
let ctx = new_sd_ctx(&sd_ctx_params);
self.diffusion_ctx = Some((ctx, sd_ctx_params))
}
self.diffusion_ctx.unwrap().0
}
}
}
impl Drop for ModelConfig {
    /// Releases any native contexts this config created lazily.
    fn drop(&mut self) {
        if let Some((diffusion, _params)) = self.diffusion_ctx.take() {
            // SAFETY: the pointer was produced by `new_sd_ctx` and is owned
            // exclusively by this config.
            unsafe { free_sd_ctx(diffusion) };
        }
        if let Some(upscaler) = self.upscaler_ctx.take() {
            // SAFETY: the pointer was produced by `new_upscaler_ctx` and is
            // owned exclusively by this config.
            unsafe { free_upscaler_ctx(upscaler) };
        }
    }
}
/// Turns a finished [`ModelConfig`] back into a builder, e.g. to derive a
/// modified configuration.
///
/// Every user-settable option is carried over. Native contexts are not: the
/// resulting config creates its own on first use. Embeddings are re-scanned
/// from the original embeddings directory.
impl From<ModelConfig> for ModelConfigBuilder {
    fn from(value: ModelConfig) -> Self {
        let mut builder = ModelConfigBuilder::default();
        builder
            .n_threads(value.n_threads)
            .enable_mmap(value.enable_mmap)
            .offload_params_to_cpu(value.offload_params_to_cpu)
            .upscale_repeats(value.upscale_repeats)
            .upscale_tile_size(value.upscale_tile_size)
            .model(value.model.clone())
            .diffusion_model(value.diffusion_model.clone())
            .llm(value.llm.clone())
            .llm_vision(value.llm_vision.clone())
            .clip_l(value.clip_l.clone())
            .clip_g(value.clip_g.clone())
            .clip_vision(value.clip_vision.clone())
            .t5xxl(value.t5xxl.clone())
            .vae(value.vae.clone())
            .taesd(value.taesd.clone())
            .control_net(value.control_net.clone())
            .photo_maker(value.photo_maker.clone())
            .pm_id_embed_path(value.pm_id_embed_path.clone())
            .weight_type(value.weight_type)
            .high_noise_diffusion_model(value.high_noise_diffusion_model.clone())
            .vae_tiling(value.vae_tiling)
            .vae_tile_size(value.vae_tile_size)
            .vae_relative_tile_size(value.vae_relative_tile_size)
            .vae_tile_overlap(value.vae_tile_overlap)
            .rng(value.rng)
            .sampler_rng_type(value.sampler_rng_type)
            .scheduler(value.scheduler)
            .sigmas(value.sigmas.clone())
            .prediction(value.prediction)
            .vae_on_cpu(value.vae_on_cpu)
            .clip_on_cpu(value.clip_on_cpu)
            .control_net_cpu(value.control_net_cpu)
            .diffusion_flash_attention(value.diffusion_flash_attention)
            .flash_attention(value.flash_attention)
            .chroma_disable_dit_mask(value.chroma_disable_dit_mask)
            .chroma_enable_t5_mask(value.chroma_enable_t5_mask)
            .chroma_t5_mask_pad(value.chroma_t5_mask_pad)
            .diffusion_conv_direct(value.diffusion_conv_direct)
            .vae_conv_direct(value.vae_conv_direct)
            .force_sdxl_vae_conv_scale(value.force_sdxl_vae_conv_scale)
            .flow_shift(value.flow_shift)
            .timestep_shift(value.timestep_shift)
            .taesd_preview_only(value.taesd_preview_only)
            .lora_apply_mode(value.lora_apply_mode)
            .circular(value.circular)
            .circular_x(value.circular_x)
            .circular_y(value.circular_y)
            .use_qwen_image_zero_cond_true(value.use_qwen_image_zero_cond_true);
        // An empty directory means embeddings were never configured. Leaving
        // the field unset yields the same default storage without walking "".
        let embeddings_dir = &value.embeddings.0;
        if !embeddings_dir.as_os_str().is_empty() {
            builder.embeddings(embeddings_dir);
        }
        builder.lora_models_internal(value.lora_models.clone());
        if let Some(model) = &value.upscale_model {
            builder.upscale_model(model.clone());
        }
        builder
    }
}
/// Per-generation settings: prompts, sampling, input/output images, preview
/// and cache options.
///
/// Construct with [`ConfigBuilder`]; only `prompt` has no default. `build()`
/// rejects configurations where `batch_count > 1` but `output` is not an
/// existing directory, and vice versa.
#[derive(Builder, Debug, Clone)]
#[builder(setter(into, strip_option), build_fn(validate = "Self::validate"))]
pub struct Config {
    #[builder(default = "Default::default()")]
    pm_id_images_dir: CLibPath,
    /// Init image; used (img2img) only when the file exists.
    #[builder(default = "Default::default()")]
    init_img: PathBuf,
    /// Mask image; used only when the file exists. Without it, img2img uses an
    /// all-white mask of the init image size.
    #[builder(default = "Default::default()")]
    mask_img: PathBuf,
    #[builder(default = "Default::default()")]
    control_image: CLibPath,
    /// Reference images; paths that do not exist are skipped.
    #[builder(default = "Default::default()")]
    ref_images: Vec<PathBuf>,
    /// Output file, or output directory when `batch_count > 1`.
    #[builder(default = "PathBuf::from(\"./output.png\")")]
    output: PathBuf,
    /// File that preview frames are written to when `preview_mode` is not
    /// `PREVIEW_NONE`.
    #[builder(default = "PathBuf::from(\"./preview_output.png\")")]
    preview_output: PathBuf,
    #[builder(default = "PreviewType::PREVIEW_NONE")]
    preview_mode: PreviewType,
    #[builder(default = "false")]
    preview_noisy: bool,
    #[builder(default = "1")]
    preview_interval: i32,
    prompt: String,
    #[builder(default = "\"\".into()")]
    negative_prompt: CLibString,
    #[builder(default = "7.0")]
    cfg_scale: f32,
    #[builder(default = "3.5")]
    guidance: f32,
    #[builder(default = "0.75")]
    strength: f32,
    #[builder(default = "20.0")]
    pm_style_strength: f32,
    #[builder(default = "0.9")]
    control_strength: f32,
    #[builder(default = "512")]
    height: i32,
    #[builder(default = "512")]
    width: i32,
    /// `SAMPLE_METHOD_COUNT` means "use the model's default sampler".
    #[builder(default = "SampleMethod::SAMPLE_METHOD_COUNT")]
    sampling_method: SampleMethod,
    #[builder(default = "0.")]
    eta: f32,
    #[builder(default = "20")]
    steps: i32,
    #[builder(default = "42")]
    seed: i64,
    #[builder(default = "1")]
    batch_count: i32,
    #[builder(default = "ClipSkip::Unspecified")]
    clip_skip: ClipSkip,
    #[builder(default = "false")]
    canny: bool,
    #[builder(default = "0.")]
    slg_scale: f32,
    #[builder(default = "vec![7, 8, 9]")]
    skip_layer: Vec<i32>,
    #[builder(default = "0.01")]
    skip_layer_start: f32,
    #[builder(default = "0.2")]
    skip_layer_end: f32,
    #[builder(default = "false")]
    disable_auto_resize_ref_image: bool,
    /// Low-level cache parameters, set through the `*_caching` builder methods.
    #[builder(default = "Self::cache_init()", private)]
    cache: sd_cache_params_t,
    /// Owned storage for the SCM mask string referenced by `cache.scm_mask`
    /// (DBCache / cache-dit modes).
    #[builder(default = "CLibString::default()", private)]
    scm_mask: CLibString,
}
impl ConfigBuilder {
    /// Step count `Config` uses when `steps` is not set; mirrors the
    /// `#[builder(default = "20")]` on `Config::steps`.
    const DEFAULT_STEPS: i32 = 20;

    /// Validation hook run by `build()`.
    fn validate(&self) -> Result<(), ConfigBuilderError> {
        self.validate_output_dir()
    }

    /// `output` must be an existing directory exactly when `batch_count > 1`.
    fn validate_output_dir(&self) -> Result<(), ConfigBuilderError> {
        let is_dir = self.output.as_ref().is_some_and(|val| val.is_dir());
        let multiple_items = self.batch_count.as_ref().is_some_and(|val| *val > 1);
        if is_dir == multiple_items {
            Ok(())
        } else {
            Err(ConfigBuilderError::ValidationError(
                "When batch_count > 1, output should point to folder and vice versa".to_owned(),
            ))
        }
    }

    /// Baseline cache parameters with caching disabled. Every `*_caching`
    /// method starts from these and overrides its mode-specific fields.
    fn cache_init() -> sd_cache_params_t {
        sd_cache_params_t {
            mode: sd_cache_mode_t::SD_CACHE_DISABLED,
            reuse_threshold: 1.0,
            start_percent: 0.15,
            end_percent: 0.95,
            error_decay_rate: 1.0,
            use_relative_threshold: true,
            reset_error_on_compute: true,
            Fn_compute_blocks: 8,
            Bn_compute_blocks: 0,
            residual_diff_threshold: 0.08,
            max_warmup_steps: 8,
            max_cached_steps: -1,
            max_continuous_cached_steps: -1,
            taylorseer_n_derivatives: 1,
            taylorseer_skip_interval: 1,
            scm_mask: null(),
            scm_policy_dynamic: true,
        }
    }

    /// Disables step caching.
    pub fn no_caching(&mut self) -> &mut Self {
        let mut cache = Self::cache_init();
        cache.mode = sd_cache_mode_t::SD_CACHE_DISABLED;
        self.cache = Some(cache);
        self
    }

    /// Enables UCache with the given parameters.
    pub fn ucache_caching(&mut self, params: UCacheParams) -> &mut Self {
        let mut cache = Self::cache_init();
        cache.mode = sd_cache_mode_t::SD_CACHE_UCACHE;
        cache.reuse_threshold = params.threshold;
        cache.start_percent = params.start;
        cache.end_percent = params.end;
        cache.error_decay_rate = params.decay;
        cache.use_relative_threshold = params.relative;
        cache.reset_error_on_compute = params.reset;
        self.cache = Some(cache);
        self
    }

    /// Enables EasyCache with the given parameters.
    pub fn easy_cache_caching(&mut self, params: EasyCacheParams) -> &mut Self {
        let mut cache = Self::cache_init();
        cache.mode = sd_cache_mode_t::SD_CACHE_EASYCACHE;
        cache.reuse_threshold = params.threshold;
        cache.start_percent = params.start;
        cache.end_percent = params.end;
        self.cache = Some(cache);
        self
    }

    /// Enables DBCache with the given parameters.
    ///
    /// The SCM mask is rendered for the `steps` value set *so far*, or for
    /// `DEFAULT_STEPS` when none was set, so call `steps(...)` before this
    /// method.
    pub fn db_cache_caching(&mut self, params: DbCacheParams) -> &mut Self {
        let mut cache = Self::cache_init();
        cache.mode = sd_cache_mode_t::SD_CACHE_DBCACHE;
        cache.Fn_compute_blocks = params.fn_blocks;
        cache.Bn_compute_blocks = params.bn_blocks;
        cache.residual_diff_threshold = params.threshold;
        cache.max_warmup_steps = params.warmup;
        cache.scm_policy_dynamic = match params.scm_policy_dynamic {
            ScmPolicy::Static => false,
            ScmPolicy::Dynamic => true,
        };
        // Falling back to 0 would render an empty mask; use the real default.
        let steps = self.steps.unwrap_or(Self::DEFAULT_STEPS);
        let mask = CLibString::from(params.scm_mask.to_vec_string(steps));
        // The CString's heap buffer does not move when `mask` is stored below.
        cache.scm_mask = mask.as_ptr();
        self.scm_mask = Some(mask);
        self.cache = Some(cache);
        self
    }

    /// Enables TaylorSeer caching with default parameters.
    pub fn taylor_seer_caching(&mut self) -> &mut Self {
        let mut cache = Self::cache_init();
        cache.mode = sd_cache_mode_t::SD_CACHE_TAYLORSEER;
        self.cache = Some(cache);
        self
    }

    /// Enables cache-dit: DBCache parameters with the `SD_CACHE_CACHE_DIT` mode.
    pub fn cache_dit_caching(&mut self, params: DbCacheParams) -> &mut Self {
        self.db_cache_caching(params);
        // Mutate the stored value in place. `Option::unwrap()` would return a
        // copy, and setting its mode would leave `self.cache` unchanged.
        if let Some(cache) = self.cache.as_mut() {
            cache.mode = sd_cache_mode_t::SD_CACHE_CACHE_DIT;
        }
        self
    }
}
/// Turns a finished [`Config`] back into a builder so it can be tweaked
/// (e.g. a new prompt or output) and rebuilt. Every field is carried over.
impl From<Config> for ConfigBuilder {
    fn from(value: Config) -> Self {
        // Re-point the cache's mask pointer at the string the builder will own,
        // but only when a mask is in use: NULL (no SCM mask) must stay NULL.
        // The CString's heap buffer does not move when it is handed to the
        // builder below.
        let scm_mask = value.scm_mask;
        let mut cache = value.cache;
        if !cache.scm_mask.is_null() {
            cache.scm_mask = scm_mask.as_ptr();
        }
        let mut builder = ConfigBuilder::default();
        builder
            .pm_id_images_dir(value.pm_id_images_dir)
            .init_img(value.init_img)
            .mask_img(value.mask_img)
            .control_image(value.control_image)
            .ref_images(value.ref_images)
            .output(value.output)
            .preview_output(value.preview_output)
            .preview_mode(value.preview_mode)
            .preview_noisy(value.preview_noisy)
            .preview_interval(value.preview_interval)
            .prompt(value.prompt)
            .negative_prompt(value.negative_prompt)
            .cfg_scale(value.cfg_scale)
            .guidance(value.guidance)
            .strength(value.strength)
            .pm_style_strength(value.pm_style_strength)
            .control_strength(value.control_strength)
            .height(value.height)
            .width(value.width)
            .sampling_method(value.sampling_method)
            .eta(value.eta)
            .steps(value.steps)
            .seed(value.seed)
            .batch_count(value.batch_count)
            .clip_skip(value.clip_skip)
            .canny(value.canny)
            .slg_scale(value.slg_scale)
            .skip_layer(value.skip_layer)
            .skip_layer_start(value.skip_layer_start)
            .skip_layer_end(value.skip_layer_end)
            .disable_auto_resize_ref_image(value.disable_auto_resize_ref_image)
            .cache(cache)
            .scm_mask(scm_mask);
        builder
    }
}
/// Owned NUL-terminated string handed to stable-diffusion.cpp.
#[derive(Debug, Clone, Default)]
struct CLibString(CString);
impl CLibString {
    /// Pointer valid for as long as `self` is alive and unmodified.
    fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr()
    }

    /// Builds a `CString`, truncating at the first interior NUL instead of
    /// panicking. The C side would stop reading there anyway.
    fn from_bytes_truncated(bytes: Vec<u8>) -> CString {
        match CString::new(bytes) {
            Ok(c_string) => c_string,
            Err(err) => {
                let nul_at = err.nul_position();
                let mut bytes = err.into_vec();
                bytes.truncate(nul_at);
                CString::new(bytes).expect("prefix before the first NUL contains no NUL")
            }
        }
    }
}
impl From<&str> for CLibString {
    fn from(value: &str) -> Self {
        Self(Self::from_bytes_truncated(value.as_bytes().to_vec()))
    }
}
impl From<String> for CLibString {
    fn from(value: String) -> Self {
        Self(Self::from_bytes_truncated(value.into_bytes()))
    }
}
#[derive(Debug, Clone, Default)]
struct CLibPath(CString);
impl CLibPath {
fn as_ptr(&self) -> *const c_char {
self.0.as_ptr()
}
}
impl From<PathBuf> for CLibPath {
fn from(value: PathBuf) -> Self {
Self(CString::new(value.to_str().unwrap_or_default()).unwrap())
}
}
impl From<&Path> for CLibPath {
fn from(value: &Path) -> Self {
Self(CString::new(value.to_str().unwrap_or_default()).unwrap())
}
}
impl From<&CLibPath> for PathBuf {
fn from(value: &CLibPath) -> Self {
PathBuf::from(value.0.to_str().unwrap())
}
}
/// Computes the output path(s) for a batch.
///
/// A batch of exactly one writes to `path` itself. Any other size treats
/// `path` as a directory and yields `output_<timestamp>_<n>.png` for `n` in
/// `1..=batch_size`, which is empty when `batch_size <= 0`.
fn output_files(path: &Path, batch_size: i32) -> Vec<PathBuf> {
    let date = Local::now().format("%Y.%m.%d-%H.%M.%S");
    match batch_size {
        1 => vec![path.to_path_buf()],
        _ => {
            let mut names = Vec::new();
            for id in 1..=batch_size {
                names.push(path.join(format!("output_{date}_{id}.png")));
            }
            names
        }
    }
}
/// Runs the upscaler `upscale_repeats` times over `data`.
///
/// Takes ownership of `data.data` (a buffer allocated by the C library):
/// - Each intermediate buffer is released with `free` once the next
///   upscaling step succeeds.
/// - On success the caller owns the returned image's buffer.
/// - On failure the current buffer is released here before returning.
///
/// With no upscaler context, `data` is returned unchanged.
///
/// # Errors
/// [`DiffusionError::Upscaler`] if the upscaler returns a NULL pixel buffer.
///
/// # Safety
/// `upscaler_ctx` must be a live context and `data.data` a `malloc`-compatible
/// buffer (or NULL) that the caller does not use again unless it gets it back.
unsafe fn upscale(
    upscale_repeats: i32,
    upscaler_ctx: Option<*mut upscaler_ctx_t>,
    data: sd_image_t,
) -> Result<sd_image_t, DiffusionError> {
    let Some(upscaler_ctx) = upscaler_ctx else {
        return Ok(data);
    };
    let upscale_factor = 4;
    let mut current_image = data;
    for _ in 0..upscale_repeats {
        // SAFETY: guaranteed by this function's contract.
        let upscaled_image =
            unsafe { diffusion_rs_sys::upscale(upscaler_ctx, current_image, upscale_factor) };
        // SAFETY: `current_image.data` is owned here and no longer needed,
        // either because it was superseded or because we are bailing out.
        unsafe { free(current_image.data as *mut c_void) };
        if upscaled_image.data.is_null() {
            return Err(DiffusionError::Upscaler);
        }
        current_image = upscaled_image;
    }
    Ok(current_image)
}
/// Same as [`gen_img`], but every progress report from stable-diffusion.cpp
/// is also sent to `sender` as a [`Progress`].
///
/// # Errors
/// See [`DiffusionError`].
pub fn gen_img_with_progress(
    config: &Config,
    model_config: &mut ModelConfig,
    sender: Sender<Progress>,
) -> Result<(), DiffusionError> {
    let progress_sink = Some(sender);
    gen_img_maybe_progress(config, model_config, progress_sink)
}
pub fn gen_img(config: &Config, model_config: &mut ModelConfig) -> Result<(), DiffusionError> {
gen_img_maybe_progress(config, model_config, None)
}
fn gen_img_maybe_progress(
config: &Config,
model_config: &mut ModelConfig,
mut sender: Option<Sender<Progress>>,
) -> Result<(), DiffusionError> {
let prompt: CLibString = CLibString::from(config.prompt.as_str());
let files = output_files(&config.output, config.batch_count);
unsafe {
let has_init_image = config.init_img.exists();
let has_mask_image = config.mask_img.exists();
let is_decode_only = !has_init_image;
let sd_ctx = model_config.diffusion_ctx(is_decode_only);
let upscaler_ctx = model_config.upscaler_ctx();
let mut init_image = sd_image_t {
width: 0,
height: 0,
channel: 3,
data: std::ptr::null_mut(),
};
let mut mask_image = sd_image_t {
width: config.width as u32,
height: config.height as u32,
channel: 1,
data: null_mut(),
};
let mut layers = config.skip_layer.clone();
let guidance = sd_guidance_params_t {
txt_cfg: config.cfg_scale,
img_cfg: config.cfg_scale,
distilled_guidance: config.guidance,
slg: sd_slg_params_t {
layers: layers.as_mut_ptr(),
layer_count: config.skip_layer.len(),
layer_start: config.skip_layer_start,
layer_end: config.skip_layer_end,
scale: config.slg_scale,
},
};
let scheduler = if model_config.scheduler == Scheduler::SCHEDULER_COUNT {
sd_get_default_scheduler(sd_ctx, config.sampling_method)
} else {
model_config.scheduler
};
let sample_method = if config.sampling_method == SampleMethod::SAMPLE_METHOD_COUNT {
sd_get_default_sample_method(sd_ctx)
} else {
config.sampling_method
};
let sample_params = sd_sample_params_t {
guidance,
sample_method,
sample_steps: config.steps,
eta: config.eta,
scheduler,
shifted_timestep: model_config.timestep_shift,
custom_sigmas: model_config.sigmas.as_mut_ptr(),
custom_sigmas_count: model_config.sigmas.len() as i32,
flow_shift: model_config.flow_shift,
};
let control_image = sd_image_t {
width: 0,
height: 0,
channel: 3,
data: null_mut(),
};
let vae_tiling_params = sd_tiling_params_t {
enabled: model_config.vae_tiling,
tile_size_x: model_config.vae_tile_size.0,
tile_size_y: model_config.vae_tile_size.1,
target_overlap: model_config.vae_tile_overlap,
rel_size_x: model_config.vae_relative_tile_size.0,
rel_size_y: model_config.vae_relative_tile_size.1,
};
let pm_params = sd_pm_params_t {
id_images: null_mut(),
id_images_count: 0,
id_embed_path: model_config.pm_id_embed_path.as_ptr(),
style_strength: config.pm_style_strength,
};
let mut image_buffer: Vec<u8> = Vec::new();
let mut mask_buffer: Vec<u8> = Vec::new();
if has_init_image {
let img = image::open(&config.init_img)?;
image_buffer = img.to_rgb8().into_raw();
init_image = sd_image_t {
width: img.width(),
height: img.height(),
channel: 3,
data: image_buffer.as_mut_ptr(),
}
}
if has_mask_image {
let img = image::open(&config.mask_img)?;
mask_buffer = img.to_luma8().into_raw();
mask_image = sd_image_t {
width: img.width(),
height: img.height(),
channel: 1,
data: mask_buffer.as_mut_ptr(),
}
}
if !image_buffer.is_empty() && mask_buffer.is_empty() {
let img: ImageBuffer<image::Luma<u8>, Vec<u8>> =
ImageBuffer::from_pixel(init_image.width, init_image.height, image::Luma([255]));
mask_buffer = img.into_raw();
mask_image = sd_image_t {
width: init_image.width,
height: init_image.height,
channel: 1,
data: mask_buffer.as_mut_ptr(),
}
}
let mut ref_image_list = Vec::new();
let mut ref_pixel_storage = Vec::new();
for ref_path in &config.ref_images {
if ref_path.exists() {
let img = image::open(ref_path)?;
let image_data = img.to_rgb8().into_raw();
ref_pixel_storage.push(image_data);
let storage_ref = ref_pixel_storage.last_mut().unwrap();
ref_image_list.push(sd_image_t {
width: img.width(),
height: img.height(),
channel: 3,
data: storage_ref.as_mut_ptr(),
});
}
}
let num_ref_images = ref_image_list.len();
let ref_image_ptr = if num_ref_images > 0 {
ref_image_list.as_mut_ptr()
} else {
null_mut()
};
unsafe extern "C" fn save_preview_local(
_step: ::std::os::raw::c_int,
_frame_count: ::std::os::raw::c_int,
frames: *mut sd_image_t,
_is_noisy: bool,
data: *mut ::std::os::raw::c_void,
) {
unsafe {
let path = &*data.cast::<PathBuf>();
let _ = save_img(*frames, path, None);
}
}
if config.preview_mode != PreviewType::PREVIEW_NONE {
let data = &config.preview_output as *const PathBuf;
sd_set_preview_callback(
Some(save_preview_local),
config.preview_mode,
config.preview_interval,
!config.preview_noisy,
config.preview_noisy,
data as *mut c_void,
);
}
if sender.is_some() {
unsafe extern "C" fn progress_callback(
step: ::std::os::raw::c_int,
steps: ::std::os::raw::c_int,
time: f32,
data: *mut ::std::os::raw::c_void,
) {
unsafe {
let sender = &*data.cast::<Option<Sender<Progress>>>();
if let Some(sender) = sender {
let _ = sender.send(Progress { step, steps, time });
}
}
}
let sender_ptr: *mut c_void = &mut sender as *mut _ as *mut c_void;
sd_set_progress_callback(Some(progress_callback), sender_ptr);
}
let loras: Vec<sd_lora_t> = model_config
.lora_models
.iter()
.map(|(c_path, spec)| sd_lora_t {
is_high_noise: spec.is_high_noise,
multiplier: spec.multiplier,
path: c_path.as_ptr(),
})
.collect();
let sd_img_gen_params = sd_img_gen_params_t {
prompt: prompt.as_ptr(),
negative_prompt: config.negative_prompt.as_ptr(),
clip_skip: config.clip_skip as i32,
init_image,
ref_images: ref_image_ptr,
ref_images_count: num_ref_images as i32,
increase_ref_index: false,
mask_image,
width: config.width,
height: config.height,
sample_params,
strength: config.strength,
seed: config.seed,
batch_count: config.batch_count,
control_image,
control_strength: config.control_strength,
pm_params,
vae_tiling_params,
auto_resize_ref_image: config.disable_auto_resize_ref_image,
cache: config.cache,
loras: loras.as_ptr(),
lora_count: loras.len() as u32,
};
let params_str = CString::from_raw(sd_img_gen_params_to_str(&sd_img_gen_params))
.into_string()
.unwrap();
let slice = generate_image(sd_ctx, &sd_img_gen_params);
let ret = {
if slice.is_null() {
return Err(DiffusionError::Forward);
}
for (img, path) in slice::from_raw_parts(slice, config.batch_count as usize)
.iter()
.zip(files)
{
match upscale(model_config.upscale_repeats, upscaler_ctx, *img) {
Ok(img) => save_img(img, &path, Some(¶ms_str))?,
Err(err) => {
return Err(err);
}
}
}
Ok(())
};
free(slice as *mut c_void);
ret
}
}
fn save_img(img: sd_image_t, path: &Path, params: Option<&str>) -> Result<(), DiffusionError> {
let len = (img.width * img.height * img.channel) as usize;
let buffer = unsafe { slice::from_raw_parts(img.data, len).to_vec() };
let save_state = ImageBuffer::from_raw(img.width, img.height, buffer).map(|img| {
RgbImage::from(img)
.save(path)
.map_err(DiffusionError::StoreImages)
});
if let Some(Err(err)) = save_state {
return Err(err);
}
if let Some(params) = params {
let mut metadata = Metadata::new();
metadata.set_tag(ExifTag::ImageDescription(params.to_string()));
metadata.write_to_file(path)?;
}
Ok(())
}
#[cfg(test)]
mod tests {
    use image::{DynamicImage, ImageBuffer, Rgba};
    use std::path::PathBuf;
    use crate::{
        api::{ConfigBuilderError, ModelConfigBuilder},
        util::download_file_hf_hub,
    };
    use super::{ConfigBuilder, gen_img};

    /// Builder validation only; no model is loaded.
    #[test]
    fn test_required_args_txt2img() {
        // `prompt` has no default, and a model config needs `model` or
        // `diffusion_model`.
        assert!(ConfigBuilder::default().build().is_err());
        assert!(ModelConfigBuilder::default().build().is_err());
        ModelConfigBuilder::default()
            .model(PathBuf::from("./test.ckpt"))
            .build()
            .unwrap();
        ConfigBuilder::default()
            .prompt("a lovely cat driving a sport car")
            .build()
            .unwrap();
        // batch_count > 1 with a non-directory output must be rejected.
        assert!(matches!(
            ConfigBuilder::default()
                .prompt("a lovely cat driving a sport car")
                .batch_count(10)
                .build(),
            Err(ConfigBuilderError::ValidationError(_))
        ));
        ConfigBuilder::default()
            .prompt("a lovely cat driving a sport car")
            .build()
            .unwrap();
        // batch_count > 1 with a directory output is accepted.
        ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build()
            .unwrap();
    }

    /// End-to-end txt2img followed by img2img with a reference image.
    /// Ignored by default: downloads model weights from the Hugging Face Hub.
    #[ignore]
    #[test]
    fn test_img2img_gen() {
        let model_path =
            download_file_hf_hub("CompVis/stable-diffusion-v-1-4-original", "sd-v1-4.ckpt")
                .unwrap();
        let gen_img_output = "./output_img.png";
        let config = ConfigBuilder::default()
            .prompt("A high quality 3d texture")
            .output(PathBuf::from(gen_img_output))
            .batch_count(1)
            .build()
            .unwrap();
        let mut model_config = ModelConfigBuilder::default()
            .model(model_path)
            .build()
            .unwrap();
        // First pass: txt2img; its output becomes the init image below.
        gen_img(&config, &mut model_config).unwrap();
        // Synthetic 512x512 gradient used as a reference image.
        let mut cond = ImageBuffer::new(512, 512);
        for (x, y, pixel) in cond.enumerate_pixels_mut() {
            let r = (x as f32 / 512.0 * 255.0) as u8;
            let g = (y as f32 / 512.0 * 255.0) as u8;
            let b = 127;
            *pixel = Rgba([r, g, b, 255]);
        }
        let cond_path = "test_cond_image.png";
        DynamicImage::ImageRgba8(cond)
            .save(cond_path)
            .expect("Failed to save reference image");
        let refine_prompt = "PBR texture map, matching the lighting and micro-detail density of the reference image.";
        let img2img_config = ConfigBuilder::default()
            .prompt(refine_prompt)
            .output(PathBuf::from("./output_img_ref.png"))
            .ref_images(vec![PathBuf::from(cond_path)])
            .init_img(PathBuf::from(gen_img_output))
            .batch_count(1)
            .build()
            .unwrap();
        // img2img needs a non-decode-only context, so the cached one is recreated;
        // the final txt2img call switches back again.
        gen_img(&img2img_config, &mut model_config).unwrap();
        gen_img(&config, &mut model_config).unwrap();
    }

    /// End-to-end txt2img with upscaling, reusing one ModelConfig across
    /// configs derived via `ConfigBuilder::from`.
    /// Ignored by default: downloads model weights from the Hugging Face Hub.
    #[ignore]
    #[test]
    fn test_img_gen() {
        let model_path =
            download_file_hf_hub("CompVis/stable-diffusion-v-1-4-original", "sd-v1-4.ckpt")
                .unwrap();
        let upscaler_path = download_file_hf_hub(
            "ximso/RealESRGAN_x4plus_anime_6B",
            "RealESRGAN_x4plus_anime_6B.pth",
        )
        .unwrap();
        let config = ConfigBuilder::default()
            .prompt("a lovely duck drinking water from a bottle")
            .output(PathBuf::from("./output_1.png"))
            .batch_count(1)
            .build()
            .unwrap();
        let mut model_config = ModelConfigBuilder::default()
            .model(model_path)
            .upscale_model(upscaler_path)
            .upscale_repeats(1)
            .build()
            .unwrap();
        gen_img(&config, &mut model_config).unwrap();
        let config2 = ConfigBuilder::from(config.clone())
            .prompt("a lovely duck drinking water from a straw")
            .output(PathBuf::from("./output_2.png"))
            .build()
            .unwrap();
        gen_img(&config2, &mut model_config).unwrap();
        // Batch of 2 written into the current directory.
        let config3 = ConfigBuilder::from(config)
            .prompt("a lovely dog drinking water from a starbucks cup")
            .batch_count(2)
            .output(PathBuf::from("./"))
            .build()
            .unwrap();
        gen_img(&config3, &mut model_config).unwrap();
    }
}