use anyhow::{bail, Result};
use candle_core::{DType, Device, Module, Tensor, D};
use candle_transformers::models::stable_diffusion;
use candle_transformers::models::stable_diffusion::schedulers::PredictionType;
use mold_core::{GenerateRequest, GenerateResponse, ImageData, ModelPaths, Scheduler};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use std::time::Instant;
use crate::cache::{
cfg_prompt_cache_key, clear_cache, get_or_insert_cached_tensor, image_size_cache_key,
latent_size_cache_key, restore_cached_tensor, CachedTensor, CfgPromptCacheKey,
ImageSizeCacheKey, LatentSizeCacheKey, LruCache, DEFAULT_IMAGE_CACHE_CAPACITY,
DEFAULT_PROMPT_CACHE_CAPACITY,
};
use crate::cfg_plus_ddim::DdimAlphaSchedule;
use crate::device::{check_memory_budget, memory_status_string, preflight_memory_check};
use crate::engine::{cfg_active, rand_seed, resolve_cfg_plus, InferenceEngine, LoadStrategy};
use crate::engine_base::EngineBase;
use crate::image::{build_output_metadata, encode_image};
use crate::progress::{ProgressCallback, ProgressEvent};
/// SDXL components kept resident between generations when the engine is loaded
/// eagerly (see `SDXLEngine::load`).
struct LoadedSDXL {
/// The UNet. `None` when it is not resident; `SDXLEngine::reload_unet_if_needed`
/// rebuilds it on `device`/`dtype` in that case.
unet: Option<stable_diffusion::unet_2d::UNet2DConditionModel>,
/// VAE autoencoder, built with `vae_dtype` weights.
vae: stable_diffusion::vae::AutoEncoderKL,
/// CLIP-L text encoder (built in F32 on `clip_device`).
clip_l: stable_diffusion::clip::ClipTextTransformer,
/// CLIP-G text encoder (built in F32 on `clip_device`).
clip_g: stable_diffusion::clip::ClipTextTransformer,
/// Tokenizer for CLIP-L; may be shared through the engine's shared pool.
tokenizer_l: Arc<tokenizers::Tokenizer>,
/// Tokenizer for CLIP-G; may be shared through the engine's shared pool.
tokenizer_g: Arc<tokenizers::Tokenizer>,
/// Config the components were built from (Turbo or base SDXL).
sd_config: stable_diffusion::StableDiffusionConfig,
/// Device holding the UNet and VAE.
device: Device,
/// Device holding the text encoders; chosen from the pending placement's
/// text-encoder tier, falling back to `device`.
clip_device: Device,
/// Compute dtype: F16 when `device` is a GPU, F32 otherwise.
dtype: DType,
/// Dtype of the VAE weights, from `resolve_sdxl_vae_dtype`.
vae_dtype: DType,
}
/// Inference engine for SDXL (base or Turbo) models, loaded either from a
/// diffusers directory layout or from a single-file checkpoint.
pub struct SDXLEngine {
/// Shared engine state: model paths, load strategy, progress reporting and
/// the resident components (if loaded).
base: EngineBase<LoadedSDXL>,
/// Default scheduler; a request's own `scheduler` takes precedence.
scheduler: Scheduler,
/// Selects the Turbo config and the Turbo VAE latent scale.
is_turbo: bool,
/// When set, tokenizers and VAE CPU tensors are fetched through this pool.
shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
/// Text conditioning keyed by (prompt, negative prompt, guidance).
prompt_cache: Mutex<LruCache<CfgPromptCacheKey, CachedTensor>>,
/// VAE-encoded img2img source latents keyed by (image bytes, width, height).
source_latent_cache: Mutex<LruCache<ImageSizeCacheKey, CachedTensor>>,
/// Decoded inpaint masks keyed by (mask bytes, latent height, latent width).
mask_cache: Mutex<LruCache<LatentSizeCacheKey, CachedTensor>>,
/// Requested device placement; only the text-encoder tier is read in this view.
pending_placement: Option<mold_core::types::DevicePlacement>,
/// Set when built via `from_single_file`; all weights then come from this file.
pub(crate) single_file_path: Option<PathBuf>,
/// LoRA stack applied whenever the UNet is built via `build_unet_for_strategy`.
pending_loras: Vec<mold_core::LoraWeight>,
/// NOTE(review): presumably the `lora_stack_fingerprint` of the stack baked into
/// the resident UNet; it is not read in this part of the file — confirm.
active_lora_fingerprint: Vec<(String, u64)>,
}
/// Summarises a LoRA stack as `(path, scale bit pattern)` pairs in stack order.
///
/// The scale is stored as its raw `to_bits()` value so two fingerprints can be
/// compared exactly with `==`.
fn lora_stack_fingerprint(loras: &[mold_core::LoraWeight]) -> Vec<(String, u64)> {
    let mut fingerprint = Vec::with_capacity(loras.len());
    for weight in loras {
        let scale_bits = weight.scale.to_bits();
        fingerprint.push((weight.path.clone(), scale_bits));
    }
    fingerprint
}
/// Latent scale used when the engine is not Turbo: encoded VAE latents are
/// multiplied by it and latents are divided by it before decoding.
// NOTE(review): confirm this value is intended for non-Turbo SDXL checkpoints;
// it differs from the Turbo constant below.
const VAE_SCALE_STANDARD: f64 = 0.18215;
/// Latent scale used when `is_turbo` is set (same encode/decode usage).
const VAE_SCALE_TURBO: f64 = 0.13025;
/// Chooses the dtype for SDXL VAE weights.
///
/// Single-file checkpoints start from F32; diffusers layouts start from the
/// engine's compute dtype. That baseline is handed to
/// `crate::device::resolve_vae_dtype`, whose result is returned (it may differ).
fn resolve_sdxl_vae_dtype(default_dtype: DType, single_file: bool) -> DType {
    let baseline = match single_file {
        true => DType::F32,
        false => default_dtype,
    };
    crate::device::resolve_vae_dtype(baseline)
}
impl SDXLEngine {
/// Creates an engine for a diffusers-layout SDXL model described by `paths`.
///
/// No weights are read here; all caches start empty at their default
/// capacities and no placement, LoRAs or single-file path are set.
pub fn new(
    model_name: String,
    paths: ModelPaths,
    scheduler: Scheduler,
    is_turbo: bool,
    load_strategy: LoadStrategy,
    gpu_ordinal: usize,
    shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
) -> Self {
    let base = EngineBase::new(model_name, paths, load_strategy, gpu_ordinal);
    let prompt_cache = Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY));
    let source_latent_cache = Mutex::new(LruCache::new(DEFAULT_IMAGE_CACHE_CAPACITY));
    let mask_cache = Mutex::new(LruCache::new(DEFAULT_IMAGE_CACHE_CAPACITY));
    Self {
        base,
        scheduler,
        is_turbo,
        shared_pool,
        prompt_cache,
        source_latent_cache,
        mask_cache,
        pending_placement: None,
        single_file_path: None,
        pending_loras: Vec::new(),
        active_lora_fingerprint: Vec::new(),
    }
}
#[allow(clippy::too_many_arguments)]
pub fn from_single_file(
model_name: String,
single_file_path: PathBuf,
clip_l_tokenizer: PathBuf,
clip_g_tokenizer: PathBuf,
scheduler: Scheduler,
is_turbo: bool,
load_strategy: LoadStrategy,
gpu_ordinal: usize,
shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
) -> Result<Self> {
if !single_file_path.exists() {
bail!(
"single-file checkpoint not found: {}",
single_file_path.display()
);
}
let bundle = crate::loader::single_file::load(
&single_file_path,
mold_catalog::families::Family::Sdxl,
)?;
let _remap = crate::loader::sdxl_keys::build_sdxl_remap(&bundle)?;
let paths = ModelPaths {
transformer: single_file_path.clone(),
transformer_shards: Vec::new(),
vae: single_file_path.clone(),
spatial_upscaler: None,
temporal_upscaler: None,
distilled_lora: None,
t5_encoder: None,
clip_encoder: Some(single_file_path.clone()),
t5_tokenizer: None,
clip_tokenizer: Some(clip_l_tokenizer),
clip_encoder_2: Some(single_file_path.clone()),
clip_tokenizer_2: Some(clip_g_tokenizer),
text_encoder_files: Vec::new(),
text_tokenizer: None,
decoder: None,
};
Ok(Self {
base: EngineBase::new(model_name, paths, load_strategy, gpu_ordinal),
scheduler,
is_turbo,
shared_pool,
prompt_cache: Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY)),
source_latent_cache: Mutex::new(LruCache::new(DEFAULT_IMAGE_CACHE_CAPACITY)),
mask_cache: Mutex::new(LruCache::new(DEFAULT_IMAGE_CACHE_CAPACITY)),
pending_placement: None,
single_file_path: Some(single_file_path),
pending_loras: Vec::new(),
active_lora_fingerprint: Vec::new(),
})
}
fn validate_paths(
&self,
) -> Result<(
std::path::PathBuf,
std::path::PathBuf,
std::path::PathBuf,
std::path::PathBuf,
)> {
let clip_encoder = self
.base
.paths
.clip_encoder
.as_ref()
.ok_or_else(|| anyhow::anyhow!("CLIP-L encoder path required for SDXL models"))?
.clone();
let clip_tokenizer = self
.base
.paths
.clip_tokenizer
.as_ref()
.ok_or_else(|| anyhow::anyhow!("CLIP-L tokenizer path required for SDXL models"))?
.clone();
let clip_encoder_2 = self
.base
.paths
.clip_encoder_2
.as_ref()
.ok_or_else(|| anyhow::anyhow!("CLIP-G encoder path required for SDXL models"))?
.clone();
let clip_tokenizer_2 = self
.base
.paths
.clip_tokenizer_2
.as_ref()
.ok_or_else(|| anyhow::anyhow!("CLIP-G tokenizer path required for SDXL models"))?
.clone();
for (label, path) in [
("transformer (UNet)", &self.base.paths.transformer),
("vae", &self.base.paths.vae),
("clip_encoder (CLIP-L)", &clip_encoder),
("clip_tokenizer (CLIP-L)", &clip_tokenizer),
("clip_encoder_2 (CLIP-G)", &clip_encoder_2),
("clip_tokenizer_2 (CLIP-G)", &clip_tokenizer_2),
] {
if !path.exists() {
bail!("{label} file not found: {}", path.display());
}
}
Ok((
clip_encoder,
clip_tokenizer,
clip_encoder_2,
clip_tokenizer_2,
))
}
/// Loads the tokenizer at `clip_tokenizer`, via the shared pool when attached.
///
/// `label` (e.g. "CLIP-L") identifies the tokenizer in error messages on both
/// the pooled and the direct-load path.
///
/// # Errors
/// Fails if the tokenizer cannot be loaded or parsed.
///
/// # Panics
/// Panics if the shared-pool mutex is poisoned.
fn load_clip_tokenizer(
    &self,
    clip_tokenizer: &std::path::Path,
    label: &str,
) -> Result<Arc<tokenizers::Tokenizer>> {
    if let Some(pool) = &self.shared_pool {
        // Attach the label so a pool failure says which tokenizer broke,
        // matching the direct-load error below.
        return pool
            .lock()
            .unwrap()
            .load_tokenizer(clip_tokenizer)
            .map_err(|e| e.context(format!("failed to load {label} tokenizer")));
    }
    let tokenizer = tokenizers::Tokenizer::from_file(clip_tokenizer)
        .map_err(|e| anyhow::anyhow!("failed to load {label} tokenizer: {e}"))?;
    Ok(Arc::new(tokenizer))
}
/// Returns the candle SDXL config: the Turbo variant when `is_turbo` is set,
/// base SDXL otherwise, with all three optional constructor arguments `None`.
fn sd_config(&self) -> stable_diffusion::StableDiffusionConfig {
    match self.is_turbo {
        true => stable_diffusion::StableDiffusionConfig::sdxl_turbo(None, None, None),
        false => stable_diffusion::StableDiffusionConfig::sdxl(None, None, None),
    }
}
/// Rebuilds the UNet when components are loaded but the UNet slot is empty.
///
/// Does nothing when nothing is loaded or the UNet is already present. The
/// rebuild uses the loaded device and dtype and goes through
/// `build_unet_for_strategy`, so `pending_loras` (if any) are applied.
fn reload_unet_if_needed(&mut self) -> Result<()> {
    let target = self
        .base
        .loaded
        .as_ref()
        .filter(|loaded| loaded.unet.is_none())
        .map(|loaded| (loaded.device.clone(), loaded.dtype));
    let Some((device, dtype)) = target else {
        return Ok(());
    };

    let sd_config = self.sd_config();
    self.base.progress.stage_start("Reloading UNet (GPU)");
    let started = Instant::now();
    let unet = self.build_unet_for_strategy(&sd_config, &device, dtype)?;
    if let Some(loaded) = self.base.loaded.as_mut() {
        loaded.unet = Some(unet);
    }
    self.base
        .progress
        .stage_done("Reloading UNet (GPU)", started.elapsed());
    Ok(())
}
/// Builds a UNet from whichever checkpoint layout this engine uses.
///
/// Single-file checkpoints go through the SDXL key remap; diffusers layouts
/// read `paths.transformer`. A non-empty `pending_loras` selects the
/// LoRA-wrapped variant of either path.
fn build_unet_for_strategy(
    &self,
    sd_config: &stable_diffusion::StableDiffusionConfig,
    device: &Device,
    dtype: DType,
) -> Result<stable_diffusion::unet_2d::UNet2DConditionModel> {
    let with_lora = !self.pending_loras.is_empty();
    match (self.single_file_path.as_deref(), with_lora) {
        (Some(single_file), true) => {
            let remap = Self::load_sdxl_remap(single_file)?;
            self.build_unet_single_file_with_lora(single_file, &remap, sd_config, device, dtype)
        }
        (Some(single_file), false) => {
            let remap = Self::load_sdxl_remap(single_file)?;
            Self::build_unet_single_file(single_file, &remap, sd_config, device, dtype)
        }
        (None, true) => self.build_unet_diffusers_with_lora(sd_config, device, dtype),
        (None, false) => {
            let unet =
                sd_config.build_unet(&self.base.paths.transformer, device, 4, false, dtype)?;
            Ok(unet)
        }
    }
}
fn build_unet_diffusers_with_lora(
&self,
sd_config: &stable_diffusion::StableDiffusionConfig,
device: &Device,
dtype: DType,
) -> Result<stable_diffusion::unet_2d::UNet2DConditionModel> {
use candle_core::safetensors::MmapedSafetensors;
use candle_nn::VarBuilder;
let st = unsafe { MmapedSafetensors::multi(&[&self.base.paths.transformer])? };
struct MmapBackend {
st: MmapedSafetensors,
}
impl candle_nn::var_builder::SimpleBackend for MmapBackend {
fn get(
&self,
_s: candle_core::Shape,
name: &str,
_h: candle_nn::Init,
dtype: DType,
dev: &Device,
) -> candle_core::Result<Tensor> {
let t = self.st.load(name, dev)?;
if t.dtype() != dtype {
t.to_dtype(dtype)
} else {
Ok(t)
}
}
fn get_unchecked(
&self,
name: &str,
dtype: DType,
dev: &Device,
) -> candle_core::Result<Tensor> {
let t = self.st.load(name, dev)?;
if t.dtype() != dtype {
t.to_dtype(dtype)
} else {
Ok(t)
}
}
fn contains_tensor(&self, name: &str) -> bool {
self.st.get(name).is_ok()
}
}
let inner: Box<dyn candle_nn::var_builder::SimpleBackend> = Box::new(MmapBackend { st });
let wrapped = self.wrap_with_loras(inner)?;
let vb = VarBuilder::from_backend(wrapped, dtype, device.clone());
Ok(stable_diffusion::unet_2d::UNet2DConditionModel::new(
vb,
4,
4,
false,
sd_config.unet().clone(),
)?)
}
/// Builds a single-file SDXL UNet with the pending LoRA stack applied.
///
/// The remapped single-file backend is wrapped by `wrap_with_loras` before the
/// `VarBuilder` is created.
fn build_unet_single_file_with_lora(
    &self,
    single_file: &std::path::Path,
    remap: &crate::loader::SdxlRemap,
    sd_config: &stable_diffusion::StableDiffusionConfig,
    device: &Device,
    dtype: DType,
) -> Result<stable_diffusion::unet_2d::UNet2DConditionModel> {
    let base_backend: Box<dyn candle_nn::var_builder::SimpleBackend> = Box::new(
        crate::loader::SingleFileBackend::from_sdxl_unet(single_file, remap)?,
    );
    let lora_backend = self.wrap_with_loras(base_backend)?;
    let var_builder = candle_nn::VarBuilder::from_backend(lora_backend, dtype, device.clone());
    let unet_config = sd_config.unet().clone();
    let unet =
        stable_diffusion::unet_2d::UNet2DConditionModel::new(var_builder, 4, 4, false, unet_config)?;
    Ok(unet)
}
/// Wraps `inner` in the LoRA backend for the current `pending_loras` stack.
///
/// Adapters are loaded with progress reporting, then each is paired in stack
/// order with its weight's scale and a hash of its path.
fn wrap_with_loras(
    &self,
    inner: Box<dyn candle_nn::var_builder::SimpleBackend>,
) -> Result<Box<dyn candle_nn::var_builder::SimpleBackend>> {
    let progress = &self.base.progress;
    let adapters = super::lora::load_lora_adapters(&self.pending_loras, progress)?;
    let mut specs: Vec<super::lora::SdxlLoraSpec<'_>> = Vec::new();
    for (adapter, weight) in adapters.iter().zip(&self.pending_loras) {
        specs.push(super::lora::SdxlLoraSpec {
            adapter: adapter.as_ref(),
            scale: weight.scale,
            path_hash: super::lora::lora_path_hash(&weight.path),
        });
    }
    super::lora::wrap_backend_with_lora(inner, &specs, progress, None)
}
/// Builds the VAE from the single-file checkpoint (via the SDXL key remap)
/// or from the diffusers `paths.vae` file, depending on how the engine was created.
fn build_vae_for_strategy(
    &self,
    sd_config: &stable_diffusion::StableDiffusionConfig,
    device: &Device,
    dtype: DType,
) -> Result<stable_diffusion::vae::AutoEncoderKL> {
    match self.single_file_path.as_deref() {
        Some(single_file) => {
            let remap = Self::load_sdxl_remap(single_file)?;
            Self::build_vae_single_file(single_file, &remap, sd_config, device, dtype)
        }
        None => self.build_vae_diffusers(sd_config, device, dtype),
    }
}
/// Test helper: CPU tensors for the configured VAE file from the shared pool,
/// or `None` when no pool is attached.
#[cfg(test)]
fn load_vae_cpu_tensors(&self) -> Result<Option<Arc<HashMap<String, Tensor>>>> {
    let vae_path: &Path = &self.base.paths.vae;
    self.load_vae_cpu_tensors_for_path(vae_path)
}
/// Fetches CPU-resident tensors for `vae_path` through the shared pool.
///
/// Returns `Ok(None)` when the engine has no shared pool.
///
/// # Panics
/// Panics if the shared-pool mutex is poisoned.
fn load_vae_cpu_tensors_for_path(
    &self,
    vae_path: &Path,
) -> Result<Option<Arc<HashMap<String, Tensor>>>> {
    match &self.shared_pool {
        None => Ok(None),
        Some(pool) => pool
            .lock()
            .unwrap()
            .load_safetensors_cpu_tensors(std::slice::from_ref(&vae_path)),
    }
}
/// Creates a `VarBuilder` over the VAE weights at `vae_path`.
///
/// Prefers tensors parked in the shared pool; otherwise loads the safetensors
/// file directly, reporting progress under `component`.
fn load_vae_var_builder<'a>(
    &self,
    vae_path: &Path,
    dtype: DType,
    device: &Device,
    component: &str,
) -> Result<candle_nn::VarBuilder<'a>> {
    match self.load_vae_cpu_tensors_for_path(vae_path)? {
        Some(parked) => Ok(crate::encoders::park::varbuilder_from_parked(
            parked.as_ref(),
            dtype,
            device,
        )),
        None => crate::weight_loader::load_safetensors_with_progress(
            std::slice::from_ref(&vae_path),
            dtype,
            device,
            component,
            &self.base.progress,
        ),
    }
}
/// Builds the diffusers-layout VAE (3 input / 3 output channels) from `paths.vae`.
fn build_vae_diffusers(
    &self,
    sd_config: &stable_diffusion::StableDiffusionConfig,
    device: &Device,
    dtype: DType,
) -> Result<stable_diffusion::vae::AutoEncoderKL> {
    let vae_config = sd_config.autoencoder().clone();
    let var_builder = self.load_vae_var_builder(&self.base.paths.vae, dtype, device, "VAE")?;
    let vae = stable_diffusion::vae::AutoEncoderKL::new(var_builder, 3, 3, vae_config)?;
    Ok(vae)
}
/// Loads every SDXL component and keeps them resident.
///
/// Returns immediately when components are already loaded or the strategy is
/// `LoadStrategy::Sequential`. The compute dtype is F16 when the created device
/// is a GPU and F32 otherwise; the VAE dtype comes from `resolve_sdxl_vae_dtype`.
/// Text encoders go to the device chosen for the pending placement's
/// text-encoder tier (falling back to the main device).
///
/// # Errors
/// Fails if required paths are missing or any component fails to load.
pub fn load(&mut self) -> Result<()> {
    let already_loaded = self.base.loaded.is_some();
    if already_loaded || self.base.load_strategy == LoadStrategy::Sequential {
        return Ok(());
    }

    let (clip_encoder, clip_tokenizer, clip_encoder_2, clip_tokenizer_2) =
        self.validate_paths()?;
    tracing::info!(model = %self.base.model_name, "loading SDXL model components...");

    let device = crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)?;
    let dtype = match crate::device::is_gpu(&device) {
        true => DType::F16,
        false => DType::F32,
    };
    let sd_config = self.sd_config();
    let text_encoder_tier = match &self.pending_placement {
        Some(placement) => placement.text_encoders,
        None => Default::default(),
    };
    let clip_device =
        crate::device::resolve_device(Some(text_encoder_tier), || Ok(device.clone()))?;
    let vae_dtype = resolve_sdxl_vae_dtype(dtype, self.single_file_path.is_some());

    let (unet, vae, clip_l, clip_g) = match self.single_file_path.clone() {
        Some(single_file) => self.load_components_single_file(
            &single_file,
            &sd_config,
            &device,
            &clip_device,
            dtype,
            vae_dtype,
        )?,
        None => self.load_components_diffusers(
            &clip_encoder,
            &clip_encoder_2,
            &sd_config,
            &device,
            &clip_device,
            dtype,
            vae_dtype,
        )?,
    };
    let tokenizer_l = self.load_clip_tokenizer(&clip_tokenizer, "CLIP-L")?;
    let tokenizer_g = self.load_clip_tokenizer(&clip_tokenizer_2, "CLIP-G")?;

    self.base.loaded = Some(LoadedSDXL {
        unet: Some(unet),
        vae,
        clip_l,
        clip_g,
        tokenizer_l,
        tokenizer_g,
        sd_config,
        device,
        clip_device,
        dtype,
        vae_dtype,
    });
    tracing::info!(model = %self.base.model_name, "all SDXL components loaded successfully");
    Ok(())
}
/// Loads the diffusers-layout components in order: UNet, VAE, CLIP-L, CLIP-G.
///
/// Each load is bracketed by a progress stage. The UNet is built from
/// `paths.transformer` in `dtype` and the VAE in `vae_dtype`, both on `device`;
/// both CLIP encoders are built in F32 on `clip_device`.
///
/// NOTE(review): the UNet is built with `sd_config.build_unet` directly, so
/// `pending_loras` are not applied on this path (unlike
/// `build_unet_for_strategy`). Confirm that callers rebuild the UNet when a
/// LoRA stack is pending.
///
/// # Errors
/// Fails if any component fails to load or the config has no `clip2` section.
#[allow(clippy::too_many_arguments)]
fn load_components_diffusers(
&mut self,
clip_encoder: &std::path::Path,
clip_encoder_2: &std::path::Path,
sd_config: &stable_diffusion::StableDiffusionConfig,
device: &Device,
clip_device: &Device,
dtype: DType,
vae_dtype: DType,
) -> Result<(
stable_diffusion::unet_2d::UNet2DConditionModel,
stable_diffusion::vae::AutoEncoderKL,
stable_diffusion::clip::ClipTextTransformer,
stable_diffusion::clip::ClipTextTransformer,
)> {
// UNet: 4 latent channels, compute dtype.
self.base.progress.stage_start("Loading UNet (GPU)");
let unet_start = Instant::now();
let unet = sd_config.build_unet(&self.base.paths.transformer, device, 4, false, dtype)?;
self.base
.progress
.stage_done("Loading UNet (GPU)", unet_start.elapsed());
// VAE: uses its own (possibly wider) dtype.
self.base.progress.stage_start("Loading VAE (GPU)");
let vae_start = Instant::now();
let vae = self.build_vae_diffusers(sd_config, device, vae_dtype)?;
self.base
.progress
.stage_done("Loading VAE (GPU)", vae_start.elapsed());
// CLIP-L: always F32, on the text-encoder device.
self.base.progress.stage_start("Loading CLIP-L encoder");
let clip_l_start = Instant::now();
let clip_l = stable_diffusion::build_clip_transformer(
&sd_config.clip,
clip_encoder,
clip_device,
DType::F32,
)?;
self.base
.progress
.stage_done("Loading CLIP-L encoder", clip_l_start.elapsed());
// CLIP-G: needs the secondary `clip2` config, which SDXL configs carry.
self.base.progress.stage_start("Loading CLIP-G encoder");
let clip_g_start = Instant::now();
let clip2_config = sd_config
.clip2
.as_ref()
.ok_or_else(|| anyhow::anyhow!("SDXL config missing clip2 configuration"))?;
let clip_g = stable_diffusion::build_clip_transformer(
clip2_config,
clip_encoder_2,
clip_device,
DType::F32,
)?;
self.base
.progress
.stage_done("Loading CLIP-G encoder", clip_g_start.elapsed());
Ok((unet, vae, clip_l, clip_g))
}
/// Loads all four components from one single-file checkpoint, in order:
/// UNet, VAE, CLIP-L, CLIP-G.
///
/// The SDXL key remap is built once and shared by every component. The UNet
/// uses `dtype` and the VAE `vae_dtype`, both on `device`; the CLIP encoders
/// are built in F32 on `clip_device`.
///
/// NOTE(review): the UNet comes from `build_unet_single_file` without LoRA
/// wrapping, so `pending_loras` are not applied here. Confirm that callers
/// rebuild the UNet when a LoRA stack is pending.
///
/// # Errors
/// Fails if the checkpoint cannot be remapped, any component fails to build,
/// or the config has no `clip2` section.
fn load_components_single_file(
&mut self,
single_file: &std::path::Path,
sd_config: &stable_diffusion::StableDiffusionConfig,
device: &Device,
clip_device: &Device,
dtype: DType,
vae_dtype: DType,
) -> Result<(
stable_diffusion::unet_2d::UNet2DConditionModel,
stable_diffusion::vae::AutoEncoderKL,
stable_diffusion::clip::ClipTextTransformer,
stable_diffusion::clip::ClipTextTransformer,
)> {
// One remap serves all four component backends.
let remap = Self::load_sdxl_remap(single_file)?;
self.base.progress.stage_start("Loading UNet (single-file)");
let unet_start = Instant::now();
let unet = Self::build_unet_single_file(single_file, &remap, sd_config, device, dtype)?;
self.base
.progress
.stage_done("Loading UNet (single-file)", unet_start.elapsed());
self.base.progress.stage_start("Loading VAE (single-file)");
let vae_start = Instant::now();
let vae = Self::build_vae_single_file(single_file, &remap, sd_config, device, vae_dtype)?;
self.base
.progress
.stage_done("Loading VAE (single-file)", vae_start.elapsed());
self.base
.progress
.stage_start("Loading CLIP-L (single-file)");
let clip_l_start = Instant::now();
let clip_l =
Self::build_clip_l_single_file(single_file, &remap, &sd_config.clip, clip_device)?;
self.base
.progress
.stage_done("Loading CLIP-L (single-file)", clip_l_start.elapsed());
self.base
.progress
.stage_start("Loading CLIP-G (single-file)");
let clip_g_start = Instant::now();
// CLIP-G needs the secondary `clip2` config.
let clip2_config = sd_config
.clip2
.as_ref()
.ok_or_else(|| anyhow::anyhow!("SDXL config missing clip2 configuration"))?;
let clip_g =
Self::build_clip_g_single_file(single_file, &remap, clip2_config, clip_device)?;
self.base
.progress
.stage_done("Loading CLIP-G (single-file)", clip_g_start.elapsed());
Ok((unet, vae, clip_l, clip_g))
}
/// Partitions a single-file SDXL checkpoint and builds its diffusers→A1111 key remap.
///
/// # Errors
/// Each failure is prefixed with the step that failed (partitioning or remapping).
fn load_sdxl_remap(single_file: &std::path::Path) -> Result<crate::loader::SdxlRemap> {
    let bundle =
        crate::loader::single_file::load(single_file, mold_catalog::families::Family::Sdxl)
            .map_err(|err| anyhow::anyhow!("partition single-file SDXL checkpoint: {err}"))?;
    crate::loader::build_sdxl_remap(&bundle)
        .map_err(|err| anyhow::anyhow!("build SDXL diffusers→A1111 remap: {err}"))
}
fn build_unet_single_file(
single_file: &std::path::Path,
remap: &crate::loader::SdxlRemap,
sd_config: &stable_diffusion::StableDiffusionConfig,
device: &Device,
dtype: DType,
) -> Result<stable_diffusion::unet_2d::UNet2DConditionModel> {
use crate::loader::SingleFileBackend;
use candle_nn::VarBuilder;
let backend = SingleFileBackend::from_sdxl_unet(single_file, remap)?;
let vb = VarBuilder::from_backend(Box::new(backend), dtype, device.clone());
Ok(stable_diffusion::unet_2d::UNet2DConditionModel::new(
vb,
4,
4,
false,
sd_config.unet().clone(),
)?)
}
fn build_vae_single_file(
single_file: &std::path::Path,
remap: &crate::loader::SdxlRemap,
sd_config: &stable_diffusion::StableDiffusionConfig,
device: &Device,
dtype: DType,
) -> Result<stable_diffusion::vae::AutoEncoderKL> {
use crate::loader::SingleFileBackend;
use candle_nn::VarBuilder;
let backend = SingleFileBackend::from_sdxl_vae(single_file, remap)?;
let vb = VarBuilder::from_backend(Box::new(backend), dtype, device.clone());
Ok(stable_diffusion::vae::AutoEncoderKL::new(
vb,
3,
3,
sd_config.autoencoder().clone(),
)?)
}
fn build_clip_l_single_file(
single_file: &std::path::Path,
remap: &crate::loader::SdxlRemap,
clip_config: &stable_diffusion::clip::Config,
clip_device: &Device,
) -> Result<stable_diffusion::clip::ClipTextTransformer> {
use crate::loader::SingleFileBackend;
use candle_nn::VarBuilder;
let backend = SingleFileBackend::from_sdxl_clip_l(single_file, remap)?;
let vb = VarBuilder::from_backend(Box::new(backend), DType::F32, clip_device.clone());
Ok(stable_diffusion::clip::ClipTextTransformer::new(
vb,
clip_config,
)?)
}
fn build_clip_g_single_file(
single_file: &std::path::Path,
remap: &crate::loader::SdxlRemap,
clip_config: &stable_diffusion::clip::Config,
clip_device: &Device,
) -> Result<stable_diffusion::clip::ClipTextTransformer> {
use crate::loader::SingleFileBackend;
use candle_nn::VarBuilder;
let backend = SingleFileBackend::from_sdxl_clip_g(single_file, remap)?;
let vb = VarBuilder::from_backend(Box::new(backend), DType::F32, clip_device.clone());
Ok(stable_diffusion::clip::ClipTextTransformer::new(
vb,
clip_config,
)?)
}
/// Tokenizes `prompt` (with special tokens) into a `(1, max_len)` i64 id tensor on `device`.
///
/// Shorter sequences are padded with id 0 up to `max_len`. Longer sequences are
/// truncated to `max_len`. If the encoding ends in a special token (typically
/// the end-of-text marker added by `encode(.., true)`), that token is moved
/// into the last kept slot so a truncated prompt is still terminated rather
/// than ending mid-text.
///
/// # Errors
/// Fails if tokenization or tensor creation fails.
fn tokenize(
    tokenizer: &tokenizers::Tokenizer,
    prompt: &str,
    max_len: usize,
    device: &Device,
) -> Result<Tensor> {
    let encoding = tokenizer
        .encode(prompt, true)
        .map_err(|e| anyhow::anyhow!("tokenization failed: {e}"))?;
    let mut ids = encoding.get_ids().to_vec();
    if ids.len() > max_len {
        let ends_with_special = encoding.get_special_tokens_mask().last() == Some(&1);
        let terminator = ids.last().copied();
        ids.truncate(max_len);
        if let (true, Some(token), Some(slot)) = (ends_with_special, terminator, ids.last_mut()) {
            *slot = token;
        }
    }
    ids.resize(max_len, 0);
    let ids: Vec<i64> = ids.into_iter().map(i64::from).collect();
    Ok(Tensor::new(ids, device)?.unsqueeze(0)?)
}
/// Runs the UNet denoising loop from `start_step`, updating `latents` in place.
///
/// - **CFG.** With CFG active for `guidance`, each step runs a two-item batch
///   (uncond first, matching the layout `encode_prompt` produces). The result
///   is blended as `uncond + guidance * (cond - uncond)`.
/// - **CFG++.** When `cfg_plus` is requested with CFG active and the DDIM
///   scheduler, the CFG++ DDIM update replaces `scheduler.step`. Otherwise a
///   warning is logged and the standard step is used.
/// - **Inpainting.** After each step the latents are blended (via
///   `blend_inpaint_latents`) with the original latents noised to the *next*
///   timestep. After the final step, the clean original latents are used.
/// - **Start index.** A `start_step` at or past the end of the schedule runs
///   no steps.
///
/// # Errors
/// Fails if the scheduler cannot be built or any tensor operation fails.
#[allow(clippy::too_many_arguments)]
fn denoise_loop(
    &self,
    unet: &stable_diffusion::unet_2d::UNet2DConditionModel,
    text_embeddings: &Tensor,
    sched: Scheduler,
    latents: &mut Tensor,
    guidance: f64,
    cfg_plus: bool,
    steps: u32,
    start_step: usize,
    inpaint_ctx: Option<&crate::img_utils::InpaintContext>,
) -> Result<()> {
    let use_cfg = cfg_active(guidance);
    let mut scheduler = crate::scheduler::build_scheduler(
        sched,
        steps as usize,
        PredictionType::Epsilon,
        self.is_turbo,
    )?;
    let timesteps = scheduler.timesteps().to_vec();
    // Clamp so an out-of-range start index yields an empty schedule instead of
    // a slice-index panic.
    let active_timesteps = &timesteps[start_step.min(timesteps.len())..];
    let cfg_plus_schedule = if cfg_plus && use_cfg && matches!(sched, Scheduler::Ddim) {
        Some(DdimAlphaSchedule::from_default(steps as usize))
    } else {
        if cfg_plus && !use_cfg {
            tracing::warn!(
                guidance,
                "cfg_plus requested but cfg_scale ≈ 1.0 — falling back to standard step (no uncond available)"
            );
        } else if cfg_plus {
            tracing::warn!(
                scheduler = ?sched,
                "cfg_plus requested but only DDIM is supported on SDXL/SD1.5 — falling back to standard step. Re-run with `--scheduler ddim` to enable CFG++."
            );
        }
        None
    };
    let denoise_label = format!("Denoising ({} steps)", active_timesteps.len());
    self.base.progress.stage_start(&denoise_label);
    let denoise_start = Instant::now();
    for (step_idx, &t) in active_timesteps.iter().enumerate() {
        let step_start = Instant::now();
        let latent_input = if use_cfg {
            Tensor::cat(&[&*latents, &*latents], 0)?
        } else {
            latents.clone()
        };
        let latent_input = scheduler.scale_model_input(latent_input, t)?;
        let noise_pred = unet.forward(&latent_input, t as f64, text_embeddings)?;
        let (noise_pred_blended, noise_pred_uncond_opt) = if use_cfg {
            let chunks = noise_pred.chunk(2, 0)?;
            let noise_pred_uncond = chunks[0].clone();
            let noise_pred_cond = &chunks[1];
            let blended =
                (&noise_pred_uncond + ((noise_pred_cond - &noise_pred_uncond)? * guidance)?)?;
            (blended, Some(noise_pred_uncond))
        } else {
            (noise_pred, None)
        };
        *latents = match (cfg_plus_schedule.as_ref(), noise_pred_uncond_opt.as_ref()) {
            (Some(ddim_sched), Some(eps_uncond)) => {
                ddim_sched.cfg_plus_step(&*latents, &noise_pred_blended, eps_uncond, t)?
            }
            _ => scheduler.step(&noise_pred_blended, t, &*latents)?,
        };
        if let Some(ctx) = inpaint_ctx {
            // The step has advanced `latents` to the next timestep, so the
            // original must be noised to that level. After the last step it is
            // used clean, so the preserved region ends noise-free.
            let original_at_level = match active_timesteps.get(step_idx + 1) {
                Some(&t_next) => {
                    scheduler.add_noise(&ctx.original_latents, ctx.noise.clone(), t_next)?
                }
                None => ctx.original_latents.clone(),
            };
            *latents = crate::img2img::blend_inpaint_latents(&*latents, ctx, &original_at_level)?;
        }
        self.base.progress.emit(ProgressEvent::DenoiseStep {
            step: step_idx + 1,
            total: active_timesteps.len(),
            elapsed: step_start.elapsed(),
        });
    }
    self.base
        .progress
        .stage_done(&denoise_label, denoise_start.elapsed());
    Ok(())
}
/// Encodes the img2img source image and noises it to the starting timestep.
///
/// The VAE-encoded, scaled latents are cached per (image bytes, width, height).
/// `start_step` is derived from `strength`, and the noise is seeded from `seed`.
/// If `start_step` is past the last timestep, the clean latents are returned
/// without noise.
///
/// Returns `(noised_latents, start_step, clean_latents, noise)`; the last two
/// are what inpainting re-noises on each step.
#[allow(clippy::too_many_arguments)]
fn prepare_img2img_latents(
    &self,
    vae: &stable_diffusion::vae::AutoEncoderKL,
    source_bytes: &[u8],
    width: u32,
    height: u32,
    strength: f64,
    steps: u32,
    sched: Scheduler,
    seed: u64,
    device: &Device,
    dtype: DType,
    vae_dtype: DType,
) -> Result<(Tensor, usize, Tensor, Tensor)> {
    use crate::img_utils::{decode_source_image, NormalizeRange};
    let vae_scale = match self.is_turbo {
        true => VAE_SCALE_TURBO,
        false => VAE_SCALE_STANDARD,
    };
    let (encoded, was_cached) = get_or_insert_cached_tensor(
        &self.source_latent_cache,
        image_size_cache_key(source_bytes, width, height),
        device,
        dtype,
        || {
            let progress = &self.base.progress;
            progress.stage_start("Encoding source image (VAE)");
            let began = Instant::now();
            let pixels = decode_source_image(
                source_bytes,
                width,
                height,
                NormalizeRange::MinusOneToOne,
                device,
                vae_dtype,
            )?;
            let scaled = (vae.encode(&pixels)?.mode()? * vae_scale)?;
            let latents = scaled.to_dtype(dtype)?;
            progress.stage_done("Encoding source image (VAE)", began.elapsed());
            Ok(latents)
        },
    )?;
    if was_cached {
        self.base.progress.cache_hit("source image latents");
    }

    let start_step = crate::img2img::img2img_start_index(steps as usize, strength);
    let scheduler = crate::scheduler::build_scheduler(
        sched,
        steps as usize,
        PredictionType::Epsilon,
        self.is_turbo,
    )?;
    let start_timestep = scheduler.timesteps().get(start_step).copied();
    let latent_shape = [1, 4, height as usize / 8, width as usize / 8];
    let noise =
        crate::engine::seeded_randn(seed, &latent_shape, device, DType::F32)?.to_dtype(dtype)?;
    let noised = match start_timestep {
        Some(t_start) => scheduler.add_noise(&encoded, noise.clone(), t_start)?,
        None => encoded.clone(),
    };
    tracing::info!(
        start_step,
        total_steps = steps,
        strength,
        "img2img: starting from step {start_step}"
    );
    Ok((noised, start_step, encoded, noise))
}
/// Produces UNet text conditioning for `prompt`, reusing the prompt cache.
///
/// CLIP-L and CLIP-G hidden states are concatenated on the last dimension.
/// When CFG is active for `guidance`, the negative-prompt conditioning is
/// stacked in front of the positive one on dim 0 (uncond first). Encoding runs
/// on `clip_device`; the result is moved to `device` and cast to `dtype`.
#[allow(clippy::too_many_arguments)]
fn encode_prompt(
    &self,
    clip_l: &stable_diffusion::clip::ClipTextTransformer,
    clip_g: &stable_diffusion::clip::ClipTextTransformer,
    tokenizer_l: &tokenizers::Tokenizer,
    tokenizer_g: &tokenizers::Tokenizer,
    prompt: &str,
    negative_prompt: &str,
    max_len: usize,
    device: &Device,
    clip_device: &Device,
    dtype: DType,
    guidance: f64,
) -> Result<Tensor> {
    let key = cfg_prompt_cache_key(prompt, negative_prompt, guidance);
    let (conditioning, from_cache) =
        get_or_insert_cached_tensor(&self.prompt_cache, key, device, dtype, || {
            let with_uncond = cfg_active(guidance);
            let progress = &self.base.progress;

            progress.stage_start("Encoding prompt (CLIP-L)");
            let l_began = Instant::now();
            let cond_l =
                clip_l.forward(&Self::tokenize(tokenizer_l, prompt, max_len, clip_device)?)?;
            progress.stage_done("Encoding prompt (CLIP-L)", l_began.elapsed());

            progress.stage_start("Encoding prompt (CLIP-G)");
            let g_began = Instant::now();
            let cond_g =
                clip_g.forward(&Self::tokenize(tokenizer_g, prompt, max_len, clip_device)?)?;
            progress.stage_done("Encoding prompt (CLIP-G)", g_began.elapsed());

            let cond = Tensor::cat(&[&cond_l, &cond_g], D::Minus1)?;
            let batch = match with_uncond {
                false => cond,
                true => {
                    let uncond_l = clip_l.forward(&Self::tokenize(
                        tokenizer_l,
                        negative_prompt,
                        max_len,
                        clip_device,
                    )?)?;
                    let uncond_g = clip_g.forward(&Self::tokenize(
                        tokenizer_g,
                        negative_prompt,
                        max_len,
                        clip_device,
                    )?)?;
                    let uncond = Tensor::cat(&[&uncond_l, &uncond_g], D::Minus1)?;
                    Tensor::cat(&[&uncond, &cond], 0)?
                }
            };
            Ok(batch.to_device(device)?.to_dtype(dtype)?)
        })?;
    if from_cache {
        self.base.progress.cache_hit("prompt conditioning");
    }
    Ok(conditioning)
}
/// Decodes an inpaint mask at latent resolution, memoised by
/// (mask bytes, latent height, latent width).
fn cached_mask(
    &self,
    mask_bytes: &[u8],
    latent_h: usize,
    latent_w: usize,
    device: &Device,
    dtype: DType,
) -> Result<Tensor> {
    let (mask, from_cache) = get_or_insert_cached_tensor(
        &self.mask_cache,
        latent_size_cache_key(mask_bytes, latent_h, latent_w),
        device,
        dtype,
        || crate::img_utils::decode_mask_image(mask_bytes, latent_h, latent_w, device, dtype),
    )?;
    if from_cache {
        self.base.progress.cache_hit("inpaint mask");
    }
    Ok(mask)
}
fn generate_sequential(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
let (clip_encoder, clip_tokenizer, clip_encoder_2, clip_tokenizer_2) =
self.validate_paths()?;
if let Some(warning) = check_memory_budget(&self.base.paths, LoadStrategy::Sequential) {
self.base.progress.info(&warning);
}
let device = crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)?;
let dtype = if crate::device::is_gpu(&device) {
DType::F16
} else {
DType::F32
};
let sd_config = self.sd_config();
let max_len = sd_config.clip.max_position_embeddings;
let start = Instant::now();
let seed = req.seed.unwrap_or_else(rand_seed);
let width = req.width as usize;
let height = req.height as usize;
let guidance = req.guidance;
tracing::info!(
prompt = %req.prompt,
seed, width, height,
steps = req.steps,
guidance,
"starting sequential SDXL generation"
);
self.base
.progress
.info("Using sequential loading (load-use-drop) to minimize peak memory");
let neg = req.negative_prompt.as_deref().unwrap_or("");
let cache_key = cfg_prompt_cache_key(&req.prompt, neg, guidance);
let text_embeddings = if let Some(tensor) =
restore_cached_tensor(&self.prompt_cache, &cache_key, &device, dtype)?
{
self.base.progress.cache_hit("prompt conditioning");
tensor
} else {
if let Some(status) = memory_status_string() {
self.base.progress.info(&status);
}
let tokenizer_l = self.load_clip_tokenizer(&clip_tokenizer, "CLIP-L")?;
let tokenizer_g = self.load_clip_tokenizer(&clip_tokenizer_2, "CLIP-G")?;
let tier1 = self
.pending_placement
.as_ref()
.map(|p| p.text_encoders)
.unwrap_or_default();
let clip_device = crate::device::resolve_device(Some(tier1), || Ok(device.clone()))?;
let (clip_l, clip_g) =
if let Some(single_file) = self.single_file_path.clone() {
let remap = Self::load_sdxl_remap(&single_file)?;
self.base
.progress
.stage_start("Loading CLIP-L (single-file)");
let clip_l_start = Instant::now();
let clip_l = Self::build_clip_l_single_file(
&single_file,
&remap,
&sd_config.clip,
&clip_device,
)?;
self.base
.progress
.stage_done("Loading CLIP-L (single-file)", clip_l_start.elapsed());
self.base
.progress
.stage_start("Loading CLIP-G (single-file)");
let clip_g_start = Instant::now();
let clip2_config = sd_config.clip2.as_ref().ok_or_else(|| {
anyhow::anyhow!("SDXL config missing clip2 configuration")
})?;
let clip_g = Self::build_clip_g_single_file(
&single_file,
&remap,
clip2_config,
&clip_device,
)?;
self.base
.progress
.stage_done("Loading CLIP-G (single-file)", clip_g_start.elapsed());
(clip_l, clip_g)
} else {
self.base.progress.stage_start("Loading CLIP-L encoder");
let clip_l_start = Instant::now();
let clip_l = stable_diffusion::build_clip_transformer(
&sd_config.clip,
&clip_encoder,
&clip_device,
DType::F32,
)?;
self.base
.progress
.stage_done("Loading CLIP-L encoder", clip_l_start.elapsed());
self.base.progress.stage_start("Loading CLIP-G encoder");
let clip_g_start = Instant::now();
let clip2_config = sd_config.clip2.as_ref().ok_or_else(|| {
anyhow::anyhow!("SDXL config missing clip2 configuration")
})?;
let clip_g = stable_diffusion::build_clip_transformer(
clip2_config,
&clip_encoder_2,
&clip_device,
DType::F32,
)?;
self.base
.progress
.stage_done("Loading CLIP-G encoder", clip_g_start.elapsed());
(clip_l, clip_g)
};
let text_embeddings = self.encode_prompt(
&clip_l,
&clip_g,
&tokenizer_l,
&tokenizer_g,
&req.prompt,
neg,
max_len,
&device,
&clip_device,
dtype,
guidance,
)?;
drop(clip_l);
drop(clip_g);
self.base.progress.info("Freed CLIP-L and CLIP-G encoders");
tracing::info!("CLIP encoders dropped (sequential mode)");
text_embeddings
};
let unet_size = std::fs::metadata(&self.base.paths.transformer)
.map(|m| m.len())
.unwrap_or(0);
let unet_batch = if req.guidance > 1.0 { 2 } else { 1 };
let unet_activation_budget = crate::device::activation_bytes(
req.width,
req.height,
unet_batch,
crate::device::dtype_bytes(dtype),
crate::device::ActivationFamily::SdxlUnet,
);
preflight_memory_check("UNet", unet_size, unet_activation_budget)?;
if let Some(status) = memory_status_string() {
self.base.progress.info(&status);
}
self.base.progress.stage_start("Loading UNet (GPU)");
let unet_start = Instant::now();
let unet = self.build_unet_for_strategy(&sd_config, &device, dtype)?;
self.base
.progress
.stage_done("Loading UNet (GPU)", unet_start.elapsed());
let sched = req.scheduler.unwrap_or(self.scheduler);
let is_img2img = req.source_image.is_some();
let (mut latents, start_step, inpaint_ctx) = if let Some(ref source_bytes) =
req.source_image
{
self.base
.progress
.info("img2img mode: encoding source image before denoising");
self.base.progress.stage_start("Loading VAE (GPU)");
let vae_start_t = Instant::now();
let vae_dtype = resolve_sdxl_vae_dtype(dtype, self.single_file_path.is_some());
let vae = self.build_vae_for_strategy(&sd_config, &device, vae_dtype)?;
self.base
.progress
.stage_done("Loading VAE (GPU)", vae_start_t.elapsed());
let (latents, start_step, encoded, noise) = self.prepare_img2img_latents(
&vae,
source_bytes,
req.width,
req.height,
req.strength,
req.steps,
sched,
seed,
&device,
dtype,
vae_dtype,
)?;
let inpaint_ctx = if let Some(ref mask_bytes) = req.mask_image {
let mask = self.cached_mask(mask_bytes, height / 8, width / 8, &device, dtype)?;
Some(crate::img_utils::InpaintContext {
original_latents: encoded,
mask,
noise,
})
} else {
None
};
drop(vae);
self.base
.progress
.info("Freed VAE (will reload for decode)");
device.synchronize()?;
(latents, start_step, inpaint_ctx)
} else {
let latent_h = height / 8;
let latent_w = width / 8;
let init_scheduler = crate::scheduler::build_scheduler(
sched,
req.steps as usize,
PredictionType::Epsilon,
self.is_turbo,
)?;
let init_noise_sigma = init_scheduler.init_noise_sigma();
drop(init_scheduler);
let latents = (crate::engine::seeded_randn(
seed,
&[1, 4, latent_h, latent_w],
&device,
DType::F32,
)? * init_noise_sigma)?;
(latents.to_dtype(dtype)?, 0, None)
};
self.denoise_loop(
&unet,
&text_embeddings,
sched,
&mut latents,
guidance,
resolve_cfg_plus(req),
req.steps,
start_step,
inpaint_ctx.as_ref(),
)?;
drop(inpaint_ctx);
drop(unet);
drop(text_embeddings);
device.synchronize()?;
self.base.progress.info("Freed UNet");
tracing::info!("UNet dropped (sequential mode)");
let vae_load_label = if is_img2img {
"Reloading VAE (GPU)"
} else {
"Loading VAE (GPU)"
};
self.base.progress.stage_start(vae_load_label);
let vae_start = Instant::now();
let vae_dtype = resolve_sdxl_vae_dtype(dtype, self.single_file_path.is_some());
let vae = self.build_vae_for_strategy(&sd_config, &device, vae_dtype)?;
self.base
.progress
.stage_done(vae_load_label, vae_start.elapsed());
self.base.progress.stage_start("VAE decode");
let vae_decode_start = Instant::now();
let vae_scale = if self.is_turbo {
VAE_SCALE_TURBO
} else {
VAE_SCALE_STANDARD
};
let latents = (latents / vae_scale)?;
let latents_for_vae = latents.to_dtype(vae_dtype)?;
let device_for_sync = device.clone();
let img = crate::vae_tiling::decode_with_oom_fallback(
&latents_for_vae,
|t| vae.decode(t).map_err(Into::into),
|| {
if let Err(e) = device_for_sync.synchronize() {
tracing::warn!(
"SDXL (sequential) device.synchronize() after VAE OOM failed: {e}"
);
}
},
)?;
let img = ((img / 2.)? + 0.5)?.clamp(0f32, 1f32)?;
let img = (img * 255.)?.to_dtype(DType::U8)?;
let img = img.squeeze(0)?;
self.base
.progress
.stage_done("VAE decode", vae_decode_start.elapsed());
let output_metadata = build_output_metadata(req, seed, Some(sched));
let image_bytes = encode_image(
&img,
req.resolved_output_format(),
req.width,
req.height,
output_metadata.as_ref(),
)?;
let generation_time_ms = start.elapsed().as_millis() as u64;
tracing::info!(
generation_time_ms,
seed,
"sequential SDXL generation complete"
);
Ok(GenerateResponse {
images: vec![ImageData {
data: image_bytes,
format: req.resolved_output_format(),
width: req.width,
height: req.height,
index: 0,
}],
generation_time_ms,
model: req.model.clone(),
seed_used: seed,
video: None,
gpu: None,
})
}
}
impl SDXLEngine {
    /// Runs one SDXL generation with all components resident (non-sequential load strategy).
    ///
    /// The `Sequential` load strategy is delegated to `generate_sequential`. Otherwise:
    /// 1. If the requested LoRA stack fingerprint differs from the active one, the resident
    ///    UNet is dropped before `reload_unet_if_needed` runs.
    /// 2. The positive/negative prompts are encoded with CLIP-L and CLIP-G.
    /// 3. Initial latents come either from the source image (img2img, optionally with an
    ///    inpaint mask) or from seeded noise scaled by the scheduler's `init_noise_sigma`.
    /// 4. The UNet denoises the latents and is then dropped to free memory for the VAE.
    /// 5. The VAE decodes via `decode_with_oom_fallback` and the image is encoded.
    ///
    /// # Errors
    /// Returns an error if the model is not loaded or if any encode / denoise / decode /
    /// image-encoding stage fails.
    fn generate_inner(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
        if self.base.load_strategy == LoadStrategy::Sequential {
            return self.generate_sequential(req);
        }

        let requested_stack = lora_stack_fingerprint(&self.pending_loras);
        if requested_stack != self.active_lora_fingerprint {
            // The resident UNet was built for the previous LoRA stack; drop it so the
            // reload below builds one for the requested stack.
            if let Some(loaded) = self.base.loaded.as_mut() {
                if loaded.unet.is_some() {
                    loaded.unet = None;
                    loaded.device.synchronize()?;
                    tracing::info!("SDXL UNet dropped (LoRA stack changed)");
                }
            }
            self.active_lora_fingerprint = requested_stack;
        }
        self.reload_unet_if_needed()?;

        let loaded = self
            .base
            .loaded
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("model not loaded — call load() first"))?;
        let start = Instant::now();
        let seed = req.seed.unwrap_or_else(rand_seed);
        let width = req.width as usize;
        let height = req.height as usize;
        let guidance = req.guidance;
        // Resolve the per-request scheduler override up front so the log reports the
        // scheduler that is actually used (and later recorded in the output metadata).
        let sched = req.scheduler.unwrap_or(self.scheduler);
        tracing::info!(
            prompt = %req.prompt,
            seed, width, height,
            steps = req.steps,
            guidance,
            scheduler = %sched,
            "starting SDXL generation"
        );

        let max_len = loaded.sd_config.clip.max_position_embeddings;
        let neg = req.negative_prompt.as_deref().unwrap_or("");
        let text_embeddings = self.encode_prompt(
            &loaded.clip_l,
            &loaded.clip_g,
            &loaded.tokenizer_l,
            &loaded.tokenizer_g,
            &req.prompt,
            neg,
            max_len,
            &loaded.device,
            &loaded.clip_device,
            loaded.dtype,
            guidance,
        )?;

        let (mut latents, start_step, inpaint_ctx) =
            if let Some(ref source_bytes) = req.source_image {
                let (latents, start_step, encoded, noise) = self.prepare_img2img_latents(
                    &loaded.vae,
                    source_bytes,
                    req.width,
                    req.height,
                    req.strength,
                    req.steps,
                    sched,
                    seed,
                    &loaded.device,
                    loaded.dtype,
                    loaded.vae_dtype,
                )?;
                // A mask turns img2img into inpainting: keep the encoded source latents
                // and noise so the denoise loop can blend unmasked regions back in.
                let inpaint_ctx = if let Some(ref mask_bytes) = req.mask_image {
                    let mask = self.cached_mask(
                        mask_bytes,
                        height / 8,
                        width / 8,
                        &loaded.device,
                        loaded.dtype,
                    )?;
                    Some(crate::img_utils::InpaintContext {
                        original_latents: encoded,
                        mask,
                        noise,
                    })
                } else {
                    None
                };
                (latents, start_step, inpaint_ctx)
            } else {
                // Latent space is 1/8 of pixel resolution with 4 channels.
                let latent_h = height / 8;
                let latent_w = width / 8;
                let init_scheduler = crate::scheduler::build_scheduler(
                    sched,
                    req.steps as usize,
                    PredictionType::Epsilon,
                    self.is_turbo,
                )?;
                let init_noise_sigma = init_scheduler.init_noise_sigma();
                drop(init_scheduler);
                // Noise is sampled in F32 for seed reproducibility, then cast to the
                // model dtype.
                let latents = (crate::engine::seeded_randn(
                    seed,
                    &[1, 4, latent_h, latent_w],
                    &loaded.device,
                    DType::F32,
                )? * init_noise_sigma)?;
                (latents.to_dtype(loaded.dtype)?, 0, None)
            };

        let unet = loaded
            .unet
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("UNet not loaded"))?;
        self.denoise_loop(
            unet,
            &text_embeddings,
            sched,
            &mut latents,
            guidance,
            resolve_cfg_plus(req),
            req.steps,
            start_step,
            inpaint_ctx.as_ref(),
        )?;
        drop(inpaint_ctx);
        // The text embeddings are not needed after denoising; release them before the
        // VAE decode, as the sequential path does.
        drop(text_embeddings);

        // Drop the UNet to free memory for the VAE decode. `reload_unet_if_needed` at the
        // top of this function is expected to rebuild it on the next call.
        let loaded = self
            .base
            .loaded
            .as_mut()
            .expect("model presence was checked above and `&mut self` is held");
        loaded.unet = None;
        loaded.device.synchronize()?;
        tracing::info!("UNet dropped to free VRAM for VAE decode");
        let loaded = self
            .base
            .loaded
            .as_ref()
            .expect("model presence was checked above and `&mut self` is held");

        self.base.progress.stage_start("VAE decode");
        let vae_start = Instant::now();
        let vae_scale = if self.is_turbo {
            VAE_SCALE_TURBO
        } else {
            VAE_SCALE_STANDARD
        };
        let latents = (latents / vae_scale)?;
        let latents_for_vae = latents.to_dtype(loaded.vae_dtype)?;
        let vae = &loaded.vae;
        let device_for_sync = loaded.device.clone();
        let img = crate::vae_tiling::decode_with_oom_fallback(
            &latents_for_vae,
            |t| vae.decode(t).map_err(Into::into),
            || {
                if let Err(e) = device_for_sync.synchronize() {
                    tracing::warn!(
                        "SDXL (parallel) device.synchronize() after VAE OOM failed: {e}"
                    );
                }
            },
        )?;
        // Map the decoder output from [-1, 1] to [0, 255] u8 and drop the batch dim.
        let img = ((img / 2.)? + 0.5)?.clamp(0f32, 1f32)?;
        let img = (img * 255.)?.to_dtype(DType::U8)?;
        let img = img.squeeze(0)?;
        self.base
            .progress
            .stage_done("VAE decode", vae_start.elapsed());

        let output_metadata = build_output_metadata(req, seed, Some(sched));
        let image_bytes = encode_image(
            &img,
            req.resolved_output_format(),
            req.width,
            req.height,
            output_metadata.as_ref(),
        )?;
        let generation_time_ms = start.elapsed().as_millis() as u64;
        tracing::info!(generation_time_ms, seed, "SDXL generation complete");
        Ok(GenerateResponse {
            images: vec![ImageData {
                data: image_bytes,
                format: req.resolved_output_format(),
                width: req.width,
                height: req.height,
                index: 0,
            }],
            generation_time_ms,
            model: req.model.clone(),
            seed_used: seed,
            video: None,
            gpu: None,
        })
    }
}
impl InferenceEngine for SDXLEngine {
    /// Generates an image for `req`.
    ///
    /// The request's device placement and effective SDXL LoRA stack are stashed on the
    /// engine for the duration of the call. Both are cleared again once `generate_inner`
    /// returns, whether it returned `Ok` or `Err`.
    fn generate(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
        self.pending_placement = req.placement.clone();
        self.pending_loras = super::lora::effective_sdxl_loras(req);

        let outcome = self.generate_inner(req);

        // Request-scoped state must not leak into the next generation.
        self.pending_loras.clear();
        self.pending_placement = None;
        outcome
    }

    fn model_name(&self) -> &str {
        self.base.model_name()
    }

    fn is_loaded(&self) -> bool {
        self.base.is_loaded()
    }

    /// Delegates to the inherent `SDXLEngine::load`.
    fn load(&mut self) -> Result<()> {
        SDXLEngine::load(self)
    }

    /// Unloads model weights, empties every per-engine cache, and forgets the active
    /// LoRA fingerprint.
    fn unload(&mut self) {
        self.base.unload();
        // The fingerprint describes the weights that were just unloaded.
        self.active_lora_fingerprint.clear();
        clear_cache(&self.prompt_cache);
        clear_cache(&self.source_latent_cache);
        clear_cache(&self.mask_cache);
    }

    fn set_on_progress(&mut self, callback: ProgressCallback) {
        self.base.set_on_progress(callback);
    }

    fn clear_on_progress(&mut self) {
        self.base.clear_on_progress();
    }

    fn model_paths(&self) -> Option<&mold_core::ModelPaths> {
        let paths = &self.base.paths;
        Some(paths)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::InferenceEngine;
    use crate::shared_pool::SharedPool;
    use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
    use std::collections::HashMap;
    use std::sync::{Arc, Mutex};
    use tokenizers::models::bpe::BPE;

    /// Temporarily sets or unsets an environment variable.
    ///
    /// On drop, including during unwinding after a failed assertion, it restores the
    /// value that was present before (or removes the variable if it was absent). A
    /// developer's pre-existing setting therefore survives the test, and a panicking
    /// test cannot leak its override into the rest of the test process.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<std::ffi::OsString>,
    }

    impl EnvVarGuard {
        /// Sets `key` to `value` until the guard is dropped.
        fn set(key: &'static str, value: &str) -> Self {
            let previous = std::env::var_os(key);
            // SAFETY: test-only environment mutation. Tests in this process may read the
            // environment concurrently; that is the same hazard the unguarded calls these
            // tests previously made already accepted.
            unsafe { std::env::set_var(key, value) };
            Self { key, previous }
        }

        /// Removes `key` until the guard is dropped.
        fn unset(key: &'static str) -> Self {
            let previous = std::env::var_os(key);
            // SAFETY: see `EnvVarGuard::set`.
            unsafe { std::env::remove_var(key) };
            Self { key, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.previous.take() {
                // SAFETY: see `EnvVarGuard::set`.
                Some(value) => unsafe { std::env::set_var(self.key, value) },
                // SAFETY: see `EnvVarGuard::set`.
                None => unsafe { std::env::remove_var(self.key) },
            }
        }
    }

    /// Writes a tiny safetensors file in the temp dir containing one representative
    /// key per SDXL single-file component (UNet, VAE, CLIP-L, CLIP-G). Each tensor is a
    /// single F32 zero. The file name includes `name`, the PID and a timestamp so that
    /// parallel tests do not collide. Callers are responsible for removing the file.
    fn synth_sdxl_single_file(name: &str) -> PathBuf {
        let path = std::env::temp_dir().join(format!(
            "mold-sdxl-from-sf-{}-{}-{}.safetensors",
            name,
            std::process::id(),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos(),
        ));
        let keys: &[&str] = &[
            "model.diffusion_model.input_blocks.0.0.weight",
            "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight",
            "first_stage_model.encoder.down.0.block.0.norm1.weight",
            "first_stage_model.quant_conv.weight",
            "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight",
            "conditioner.embedders.0.transformer.text_model.final_layer_norm.weight",
            "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight",
            "conditioner.embedders.1.model.text_projection",
        ];
        let f32_zero = 0.0f32.to_le_bytes().to_vec();
        let buffers: Vec<Vec<u8>> = keys.iter().map(|_| f32_zero.clone()).collect();
        let mut tensors: HashMap<String, TensorView<'_>> = HashMap::new();
        for (key, buf) in keys.iter().zip(buffers.iter()) {
            tensors.insert(
                (*key).to_string(),
                TensorView::new(SafeDtype::F32, vec![1], buf).unwrap(),
            );
        }
        serialize_to_file(&tensors, &None, &path).unwrap();
        path
    }

    #[test]
    fn from_single_file_constructs_for_synthetic_sdxl_checkpoint() {
        let single_file = synth_sdxl_single_file("ok");
        let clip_l_tok = std::env::temp_dir().join("mold-sdxl-clip-l-stub.json");
        let clip_g_tok = std::env::temp_dir().join("mold-sdxl-clip-g-stub.json");
        let engine = SDXLEngine::from_single_file(
            "juggernaut-xl-v9".to_string(),
            single_file.clone(),
            clip_l_tok,
            clip_g_tok,
            Scheduler::default(),
            false,
            LoadStrategy::Eager,
            0,
            None,
        )
        .expect("constructor must accept a valid SDXL single-file layout");
        assert_eq!(engine.model_name(), "juggernaut-xl-v9");
        assert_eq!(
            engine.single_file_path.as_deref(),
            Some(single_file.as_path()),
            "single-file path must be stashed for the future load() branch",
        );
        assert!(
            !engine.is_loaded(),
            "constructor must not eagerly materialise model weights",
        );
        let _ = std::fs::remove_file(single_file);
    }

    #[test]
    fn sdxl_loads_clip_tokenizers_through_shared_pool() {
        let dir = tempfile::tempdir().unwrap();
        let clip_l_tokenizer = dir.path().join("clip-l-tokenizer.json");
        let clip_g_tokenizer = dir.path().join("clip-g-tokenizer.json");
        tokenizers::Tokenizer::new(BPE::default())
            .save(&clip_l_tokenizer, false)
            .unwrap();
        tokenizers::Tokenizer::new(BPE::default())
            .save(&clip_g_tokenizer, false)
            .unwrap();
        let weights_path = dir.path().join("weights.safetensors");
        std::fs::write(&weights_path, b"stub").unwrap();
        let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
        let pooled_l = shared_pool
            .lock()
            .unwrap()
            .load_tokenizer(&clip_l_tokenizer)
            .unwrap();
        let pooled_g = shared_pool
            .lock()
            .unwrap()
            .load_tokenizer(&clip_g_tokenizer)
            .unwrap();
        let paths = ModelPaths {
            transformer: weights_path.clone(),
            transformer_shards: Vec::new(),
            vae: weights_path.clone(),
            spatial_upscaler: None,
            temporal_upscaler: None,
            distilled_lora: None,
            t5_encoder: None,
            clip_encoder: Some(weights_path.clone()),
            t5_tokenizer: None,
            clip_tokenizer: Some(clip_l_tokenizer.clone()),
            clip_encoder_2: Some(weights_path),
            clip_tokenizer_2: Some(clip_g_tokenizer.clone()),
            text_encoder_files: Vec::new(),
            text_tokenizer: None,
            decoder: None,
        };
        let engine = SDXLEngine::new(
            "sdxl-test".to_string(),
            paths,
            Scheduler::default(),
            false,
            LoadStrategy::Eager,
            0,
            Some(shared_pool),
        );
        let loaded_l = engine
            .load_clip_tokenizer(&clip_l_tokenizer, "CLIP-L")
            .unwrap();
        let loaded_g = engine
            .load_clip_tokenizer(&clip_g_tokenizer, "CLIP-G")
            .unwrap();
        // The engine must hand out the pool's Arc, not a freshly parsed tokenizer.
        assert!(Arc::ptr_eq(&pooled_l, &loaded_l));
        assert!(Arc::ptr_eq(&pooled_g, &loaded_g));
    }

    #[test]
    fn sdxl_loads_vae_tensors_through_shared_pool() {
        let dir = tempfile::tempdir().unwrap();
        let vae_path = dir.path().join("vae.safetensors");
        let weight = 1.0f32.to_le_bytes();
        let mut tensors = HashMap::new();
        tensors.insert(
            "encoder.conv_in.weight".to_string(),
            TensorView::new(SafeDtype::F32, vec![1], &weight).unwrap(),
        );
        serialize_to_file(&tensors, &None, &vae_path).unwrap();
        let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
        let pooled = shared_pool
            .lock()
            .unwrap()
            .load_safetensors_cpu_tensors(std::slice::from_ref(&vae_path))
            .unwrap()
            .unwrap();
        let paths = ModelPaths {
            transformer: dir.path().join("unet.safetensors"),
            transformer_shards: Vec::new(),
            vae: vae_path.clone(),
            spatial_upscaler: None,
            temporal_upscaler: None,
            distilled_lora: None,
            t5_encoder: None,
            clip_encoder: Some(dir.path().join("clip-l.safetensors")),
            t5_tokenizer: None,
            clip_tokenizer: Some(dir.path().join("clip-l-tokenizer.json")),
            clip_encoder_2: Some(dir.path().join("clip-g.safetensors")),
            clip_tokenizer_2: Some(dir.path().join("clip-g-tokenizer.json")),
            text_encoder_files: Vec::new(),
            text_tokenizer: None,
            decoder: None,
        };
        let engine = SDXLEngine::new(
            "sdxl-test".to_string(),
            paths,
            Scheduler::default(),
            false,
            LoadStrategy::Eager,
            0,
            Some(shared_pool),
        );
        let loaded = engine.load_vae_cpu_tensors().unwrap().unwrap();
        assert!(Arc::ptr_eq(&pooled, &loaded));
    }

    #[test]
    fn from_single_file_rejects_missing_file() {
        let bogus = std::env::temp_dir().join(format!(
            "mold-sdxl-from-sf-missing-{}-{}.safetensors",
            std::process::id(),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos(),
        ));
        let result = SDXLEngine::from_single_file(
            "missing".to_string(),
            bogus,
            std::env::temp_dir().join("mold-sdxl-clip-l-stub.json"),
            std::env::temp_dir().join("mold-sdxl-clip-g-stub.json"),
            Scheduler::default(),
            false,
            LoadStrategy::Eager,
            0,
            None,
        );
        assert!(
            result.is_err(),
            "constructor must surface a missing-file error before deeper parsing",
        );
    }

    #[test]
    fn load_branches_to_single_file_path_and_invokes_candle_constructors() {
        let single_file = synth_sdxl_single_file("load-branch");
        let make_stub = |label: &str| -> PathBuf {
            let path = std::env::temp_dir().join(format!(
                "mold-sdxl-{}-stub-{}-{}.json",
                label,
                std::process::id(),
                std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap()
                    .as_nanos(),
            ));
            std::fs::write(&path, b"").unwrap();
            path
        };
        let clip_l_tok = make_stub("clip-l");
        let clip_g_tok = make_stub("clip-g");
        let mut engine = SDXLEngine::from_single_file(
            "juggernaut-xl-v9".to_string(),
            single_file.clone(),
            clip_l_tok.clone(),
            clip_g_tok.clone(),
            Scheduler::Ddim,
            false,
            LoadStrategy::Eager,
            0,
            None,
        )
        .expect("constructor");
        // Force CPU only for the duration of `load`. The guard restores any prior
        // MOLD_DEVICE value even if `expect_err` panics.
        let err = {
            let _device = EnvVarGuard::set("MOLD_DEVICE", "cpu");
            SDXLEngine::load(&mut engine)
                .expect_err("synthetic checkpoint can't satisfy SDXL's full tensor set")
        };
        let msg = err.to_string();
        assert!(
            msg.contains("single-file") || msg.contains("rename rule"),
            "expected error from the single-file load layer, got: {msg}",
        );
        let _ = std::fs::remove_file(single_file);
        let _ = std::fs::remove_file(clip_l_tok);
        let _ = std::fs::remove_file(clip_g_tok);
    }

    /// Placeholder for a smoke test against a real single-file SDXL checkpoint.
    /// It is intentionally empty until a real-shape fixture is wired in, and it passes
    /// trivially when run with `--ignored`.
    #[test]
    #[ignore]
    fn from_single_file_real_shape_load_smoke() {}

    #[test]
    fn build_clip_l_single_file_dispatches_through_backend_not_diffusers_loader() {
        let single_file = synth_sdxl_single_file("seq-clip-l-dispatch");
        let remap = SDXLEngine::load_sdxl_remap(&single_file).expect("remap");
        let result = SDXLEngine::build_clip_l_single_file(
            &single_file,
            &remap,
            &stable_diffusion::clip::Config::sdxl(),
            &Device::Cpu,
        );
        let err = result.expect_err(
            "synthetic CLIP-L is missing token_embedding / position_embedding / \
             every encoder layer beyond layer 0 — construction must fail",
        );
        let msg = err.to_string();
        assert!(
            !msg.contains("cannot find tensor text_model"),
            "expected failure from the SingleFileBackend layer (e.g. 'no rename rule \
             for diffusers key text_model.embeddings.token_embedding.weight'); got the \
             diffusers `from_mmaped_safetensors` error instead — sequential dispatch \
             is still routing through `build_clip_transformer`. Got: {msg}",
        );
        let _ = std::fs::remove_file(single_file);
    }

    #[test]
    fn build_clip_g_single_file_dispatches_through_backend_not_diffusers_loader() {
        let single_file = synth_sdxl_single_file("seq-clip-g-dispatch");
        let remap = SDXLEngine::load_sdxl_remap(&single_file).expect("remap");
        let result = SDXLEngine::build_clip_g_single_file(
            &single_file,
            &remap,
            &stable_diffusion::clip::Config::sdxl2(),
            &Device::Cpu,
        );
        let err = result.expect_err("synthetic CLIP-G is incomplete");
        let msg = err.to_string();
        assert!(
            !msg.contains("cannot find tensor text_model"),
            "expected failure from SingleFileBackend, not diffusers loader. Got: {msg}",
        );
        let _ = std::fs::remove_file(single_file);
    }

    #[test]
    fn build_unet_single_file_dispatches_through_backend_not_diffusers_loader() {
        let single_file = synth_sdxl_single_file("seq-unet-dispatch");
        let remap = SDXLEngine::load_sdxl_remap(&single_file).expect("remap");
        let result = SDXLEngine::build_unet_single_file(
            &single_file,
            &remap,
            &stable_diffusion::StableDiffusionConfig::sdxl(None, None, None),
            &Device::Cpu,
            DType::F32,
        );
        let err = result.expect_err("synthetic UNet is incomplete");
        let msg = err.to_string();
        assert!(
            !msg.contains("cannot find tensor conv_in"),
            "expected failure from SingleFileBackend, not diffusers loader. Got: {msg}",
        );
        let _ = std::fs::remove_file(single_file);
    }

    #[test]
    fn from_single_file_threads_is_turbo_true() {
        let single_file = synth_sdxl_single_file("turbo");
        let clip_l_tok = std::env::temp_dir().join("mold-sdxl-turbo-clip-l-stub.json");
        let clip_g_tok = std::env::temp_dir().join("mold-sdxl-turbo-clip-g-stub.json");
        let engine = SDXLEngine::from_single_file(
            "sdxl-turbo:fp16".to_string(),
            single_file.clone(),
            clip_l_tok,
            clip_g_tok,
            Scheduler::EulerAncestral,
            true,
            LoadStrategy::Eager,
            0,
            None,
        )
        .expect("constructor must accept is_turbo = true");
        assert!(
            engine.is_turbo,
            "is_turbo arg must thread into the engine field — sdxl_config() reads this for VAE_SCALE_TURBO",
        );
        let _ = std::fs::remove_file(single_file);
    }

    #[test]
    fn from_single_file_threads_is_turbo_false() {
        let single_file = synth_sdxl_single_file("standard");
        let clip_l_tok = std::env::temp_dir().join("mold-sdxl-std-clip-l-stub.json");
        let clip_g_tok = std::env::temp_dir().join("mold-sdxl-std-clip-g-stub.json");
        let engine = SDXLEngine::from_single_file(
            "sdxl-base:fp16".to_string(),
            single_file.clone(),
            clip_l_tok,
            clip_g_tok,
            Scheduler::Ddim,
            false,
            LoadStrategy::Eager,
            0,
            None,
        )
        .expect("constructor must accept is_turbo = false");
        assert!(
            !engine.is_turbo,
            "is_turbo = false must produce a standard-config engine",
        );
        let _ = std::fs::remove_file(single_file);
    }

    #[test]
    fn single_file_sdxl_vae_defaults_to_f32_to_avoid_black_finetune_decodes() {
        // Ensure no user override is in effect; the guard restores it afterwards.
        let _vae_dtype = EnvVarGuard::unset("MOLD_VAE_DTYPE");
        assert_eq!(resolve_sdxl_vae_dtype(DType::F16, true), DType::F32);
        assert_eq!(resolve_sdxl_vae_dtype(DType::F16, false), DType::F16);
    }

    #[test]
    fn test_cfg_disabled_at_guidance_1_0() {
        assert!(!cfg_active(1.0));
    }

    #[test]
    fn test_cfg_disabled_just_below_1_0() {
        assert!(!cfg_active(1.0 - 1e-5));
    }

    #[test]
    fn test_cfg_enabled_at_guidance_1_5() {
        assert!(cfg_active(1.5));
    }

    #[test]
    fn test_cfg_enabled_at_guidance_7_5() {
        assert!(cfg_active(7.5));
    }

    /// The fingerprint must be equal for identical (path, scale) stacks and must differ
    /// when a scale changes or the stack order changes.
    #[test]
    fn lora_stack_fingerprint_equality_drives_unet_drop() {
        let a = mold_core::LoraWeight {
            path: "/x.safetensors".into(),
            scale: 0.8,
        };
        let b = mold_core::LoraWeight {
            path: "/y.safetensors".into(),
            scale: 0.4,
        };
        let same_a = mold_core::LoraWeight {
            path: "/x.safetensors".into(),
            scale: 0.8,
        };
        assert_eq!(
            lora_stack_fingerprint(&[a.clone(), b.clone()]),
            lora_stack_fingerprint(&[same_a, b.clone()])
        );
        let scaled = mold_core::LoraWeight {
            path: "/x.safetensors".into(),
            scale: 0.9,
        };
        assert_ne!(
            lora_stack_fingerprint(std::slice::from_ref(&a)),
            lora_stack_fingerprint(std::slice::from_ref(&scaled))
        );
        assert_ne!(
            lora_stack_fingerprint(&[a.clone(), b.clone()]),
            lora_stack_fingerprint(&[b, a])
        );
    }

    #[test]
    fn sdxl_prompt_cache_distinguishes_negative_prompt_changes() {
        use crate::cache::{cfg_prompt_cache_key, store_cached_tensor};
        let cache: Mutex<LruCache<CfgPromptCacheKey, CachedTensor>> =
            Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY));
        let device = Device::Cpu;
        let dtype = DType::F32;
        let embeddings = candle_core::Tensor::zeros((1, 4), dtype, &device).unwrap();
        let key_a = cfg_prompt_cache_key("a cat", "blurry", 7.0);
        store_cached_tensor(&cache, key_a.clone(), &embeddings).unwrap();
        let key_b = cfg_prompt_cache_key("a cat", "low quality", 7.0);
        let restored = restore_cached_tensor(&cache, &key_b, &device, dtype).unwrap();
        assert!(
            restored.is_none(),
            "different negative prompt must miss the cache (silent-wrong-output bug)"
        );
        let restored = restore_cached_tensor(&cache, &key_a, &device, dtype).unwrap();
        assert!(
            restored.is_some(),
            "identical (pos, neg, guidance) must hit"
        );
    }
}