1use anyhow::{bail, Result};
2use candle_core::{DType, Device, Module, Tensor, D};
3use candle_transformers::models::stable_diffusion;
4use candle_transformers::models::stable_diffusion::schedulers::PredictionType;
5use mold_core::{GenerateRequest, GenerateResponse, ImageData, ModelPaths, Scheduler};
6use std::collections::HashMap;
7use std::path::{Path, PathBuf};
8use std::sync::{Arc, Mutex};
9use std::time::Instant;
10
11use crate::cache::{
12 cfg_prompt_cache_key, clear_cache, get_or_insert_cached_tensor, image_size_cache_key,
13 latent_size_cache_key, restore_cached_tensor, CachedTensor, CfgPromptCacheKey,
14 ImageSizeCacheKey, LatentSizeCacheKey, LruCache, DEFAULT_IMAGE_CACHE_CAPACITY,
15 DEFAULT_PROMPT_CACHE_CAPACITY,
16};
17use crate::cfg_plus_ddim::DdimAlphaSchedule;
18use crate::device::{check_memory_budget, memory_status_string, preflight_memory_check};
19use crate::engine::{cfg_active, rand_seed, resolve_cfg_plus, InferenceEngine, LoadStrategy};
20use crate::engine_base::EngineBase;
21use crate::image::{build_output_metadata, encode_image};
22use crate::progress::{ProgressCallback, ProgressEvent};
23
/// Components of an SDXL model held in memory after `SDXLEngine::load`.
struct LoadedSDXL {
    /// Denoising UNet. `None` when it has been released; `reload_unet_if_needed`
    /// rebuilds it on `device` with `dtype`.
    unet: Option<stable_diffusion::unet_2d::UNet2DConditionModel>,
    /// Autoencoder, built with `vae_dtype` on `device`.
    vae: stable_diffusion::vae::AutoEncoderKL,
    /// First text encoder (CLIP-L), built in F32 on `clip_device`.
    clip_l: stable_diffusion::clip::ClipTextTransformer,
    /// Second text encoder (CLIP-G), built in F32 on `clip_device`.
    clip_g: stable_diffusion::clip::ClipTextTransformer,
    /// Tokenizer paired with `clip_l` (an `Arc`, as handed out by the shared pool).
    tokenizer_l: Arc<tokenizers::Tokenizer>,
    /// Tokenizer paired with `clip_g`.
    tokenizer_g: Arc<tokenizers::Tokenizer>,
    /// Configuration the components were built from (base SDXL or SDXL-Turbo).
    sd_config: stable_diffusion::StableDiffusionConfig,
    /// Device holding the UNet and VAE.
    device: Device,
    /// Device holding the CLIP encoders; may differ from `device` depending on the
    /// requested text-encoder placement.
    clip_device: Device,
    /// UNet dtype: F16 on GPU, F32 otherwise.
    dtype: DType,
    /// VAE dtype, as chosen by `resolve_sdxl_vae_dtype`.
    vae_dtype: DType,
}
43
/// SDXL / SDXL-Turbo image generation engine built on candle's stable-diffusion models.
///
/// Handles both diffusers-layout weights and single-file checkpoints, optional LoRA
/// stacks, and LRU caches for prompt conditioning, img2img source latents and inpaint
/// masks.
pub struct SDXLEngine {
    /// Shared engine state: model name, paths, load strategy, GPU ordinal, progress
    /// reporting, and the loaded components (`None` until loaded).
    base: EngineBase<LoadedSDXL>,
    /// Scheduler chosen at construction.
    scheduler: Scheduler,
    /// Selects the SDXL-Turbo config and VAE latent scale; also forwarded to
    /// `build_scheduler`.
    is_turbo: bool,
    /// Optional shared pool; when present, tokenizers and VAE tensors are loaded through it.
    shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
    /// Text conditioning keyed by (prompt, negative prompt, guidance).
    prompt_cache: Mutex<LruCache<CfgPromptCacheKey, CachedTensor>>,
    /// VAE-encoded img2img source latents keyed by image bytes and output size.
    source_latent_cache: Mutex<LruCache<ImageSizeCacheKey, CachedTensor>>,
    /// Inpaint masks decoded at latent resolution, keyed by mask bytes and latent size.
    mask_cache: Mutex<LruCache<LatentSizeCacheKey, CachedTensor>>,
    /// Requested device placement; its `text_encoders` tier decides where the CLIP
    /// encoders are placed.
    pending_placement: Option<mold_core::types::DevicePlacement>,
    /// Set when created via `from_single_file`; switches loading to the single-file paths.
    pub(crate) single_file_path: Option<PathBuf>,
    /// LoRA stack applied when the UNet is built through `build_unet_for_strategy`.
    pending_loras: Vec<mold_core::LoraWeight>,
    /// Fingerprint (see `lora_stack_fingerprint`) of the LoRA stack baked into the current
    /// UNet. NOTE(review): not updated in the code visible here; presumably compared with
    /// `pending_loras` elsewhere to decide when to rebuild — confirm.
    active_lora_fingerprint: Vec<(String, u64)>,
}
74
75fn lora_stack_fingerprint(loras: &[mold_core::LoraWeight]) -> Vec<(String, u64)> {
80 loras
81 .iter()
82 .map(|w| (w.path.clone(), w.scale.to_bits()))
83 .collect()
84}
85
/// Factor applied to VAE-encoded latents (distribution mode) for non-Turbo SDXL.
///
/// NOTE(review): diffusers' SDXL VAE config uses `scaling_factor = 0.13025`; 0.18215 is
/// the SD 1.x factor (also used by candle's SDXL example). Confirm this is intentional and
/// consistent with the decode path.
const VAE_SCALE_STANDARD: f64 = 0.18215;
/// Factor applied to VAE-encoded latents for SDXL-Turbo.
const VAE_SCALE_TURBO: f64 = 0.13025;
90
91fn resolve_sdxl_vae_dtype(default_dtype: DType, single_file: bool) -> DType {
92 let default = if single_file {
93 DType::F32
97 } else {
98 default_dtype
99 };
100 crate::device::resolve_vae_dtype(default)
101}
102
103impl SDXLEngine {
104 pub fn new(
105 model_name: String,
106 paths: ModelPaths,
107 scheduler: Scheduler,
108 is_turbo: bool,
109 load_strategy: LoadStrategy,
110 gpu_ordinal: usize,
111 shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
112 ) -> Self {
113 Self {
114 base: EngineBase::new(model_name, paths, load_strategy, gpu_ordinal),
115 scheduler,
116 is_turbo,
117 shared_pool,
118 prompt_cache: Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY)),
119 source_latent_cache: Mutex::new(LruCache::new(DEFAULT_IMAGE_CACHE_CAPACITY)),
120 mask_cache: Mutex::new(LruCache::new(DEFAULT_IMAGE_CACHE_CAPACITY)),
121 pending_placement: None,
122 single_file_path: None,
123 pending_loras: Vec::new(),
124 active_lora_fingerprint: Vec::new(),
125 }
126 }
127
128 #[allow(clippy::too_many_arguments)]
154 pub fn from_single_file(
155 model_name: String,
156 single_file_path: PathBuf,
157 clip_l_tokenizer: PathBuf,
158 clip_g_tokenizer: PathBuf,
159 scheduler: Scheduler,
160 is_turbo: bool,
161 load_strategy: LoadStrategy,
162 gpu_ordinal: usize,
163 shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
164 ) -> Result<Self> {
165 if !single_file_path.exists() {
166 bail!(
167 "single-file checkpoint not found: {}",
168 single_file_path.display()
169 );
170 }
171
172 let bundle = crate::loader::single_file::load(
175 &single_file_path,
176 mold_catalog::families::Family::Sdxl,
177 )?;
178
179 let _remap = crate::loader::sdxl_keys::build_sdxl_remap(&bundle)?;
186
187 let paths = ModelPaths {
193 transformer: single_file_path.clone(),
194 transformer_shards: Vec::new(),
195 vae: single_file_path.clone(),
196 spatial_upscaler: None,
197 temporal_upscaler: None,
198 distilled_lora: None,
199 t5_encoder: None,
200 clip_encoder: Some(single_file_path.clone()),
201 t5_tokenizer: None,
202 clip_tokenizer: Some(clip_l_tokenizer),
203 clip_encoder_2: Some(single_file_path.clone()),
204 clip_tokenizer_2: Some(clip_g_tokenizer),
205 text_encoder_files: Vec::new(),
206 text_tokenizer: None,
207 decoder: None,
208 };
209
210 Ok(Self {
218 base: EngineBase::new(model_name, paths, load_strategy, gpu_ordinal),
219 scheduler,
220 is_turbo,
221 shared_pool,
222 prompt_cache: Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY)),
223 source_latent_cache: Mutex::new(LruCache::new(DEFAULT_IMAGE_CACHE_CAPACITY)),
224 mask_cache: Mutex::new(LruCache::new(DEFAULT_IMAGE_CACHE_CAPACITY)),
225 pending_placement: None,
226 single_file_path: Some(single_file_path),
227 pending_loras: Vec::new(),
228 active_lora_fingerprint: Vec::new(),
229 })
230 }
231
232 fn validate_paths(
234 &self,
235 ) -> Result<(
236 std::path::PathBuf,
237 std::path::PathBuf,
238 std::path::PathBuf,
239 std::path::PathBuf,
240 )> {
241 let clip_encoder = self
242 .base
243 .paths
244 .clip_encoder
245 .as_ref()
246 .ok_or_else(|| anyhow::anyhow!("CLIP-L encoder path required for SDXL models"))?
247 .clone();
248 let clip_tokenizer = self
249 .base
250 .paths
251 .clip_tokenizer
252 .as_ref()
253 .ok_or_else(|| anyhow::anyhow!("CLIP-L tokenizer path required for SDXL models"))?
254 .clone();
255 let clip_encoder_2 = self
256 .base
257 .paths
258 .clip_encoder_2
259 .as_ref()
260 .ok_or_else(|| anyhow::anyhow!("CLIP-G encoder path required for SDXL models"))?
261 .clone();
262 let clip_tokenizer_2 = self
263 .base
264 .paths
265 .clip_tokenizer_2
266 .as_ref()
267 .ok_or_else(|| anyhow::anyhow!("CLIP-G tokenizer path required for SDXL models"))?
268 .clone();
269
270 for (label, path) in [
271 ("transformer (UNet)", &self.base.paths.transformer),
272 ("vae", &self.base.paths.vae),
273 ("clip_encoder (CLIP-L)", &clip_encoder),
274 ("clip_tokenizer (CLIP-L)", &clip_tokenizer),
275 ("clip_encoder_2 (CLIP-G)", &clip_encoder_2),
276 ("clip_tokenizer_2 (CLIP-G)", &clip_tokenizer_2),
277 ] {
278 if !path.exists() {
279 bail!("{label} file not found: {}", path.display());
280 }
281 }
282
283 Ok((
284 clip_encoder,
285 clip_tokenizer,
286 clip_encoder_2,
287 clip_tokenizer_2,
288 ))
289 }
290
291 fn load_clip_tokenizer(
292 &self,
293 clip_tokenizer: &std::path::Path,
294 label: &str,
295 ) -> Result<Arc<tokenizers::Tokenizer>> {
296 if let Some(ref pool) = self.shared_pool {
297 return pool.lock().unwrap().load_tokenizer(clip_tokenizer);
298 }
299 Ok(Arc::new(
300 tokenizers::Tokenizer::from_file(clip_tokenizer)
301 .map_err(|e| anyhow::anyhow!("failed to load {label} tokenizer: {e}"))?,
302 ))
303 }
304
305 fn sd_config(&self) -> stable_diffusion::StableDiffusionConfig {
307 if self.is_turbo {
308 stable_diffusion::StableDiffusionConfig::sdxl_turbo(None, None, None)
309 } else {
310 stable_diffusion::StableDiffusionConfig::sdxl(None, None, None)
311 }
312 }
313
314 fn reload_unet_if_needed(&mut self) -> Result<()> {
316 let needs_reload = self
317 .base
318 .loaded
319 .as_ref()
320 .map(|l| l.unet.is_none())
321 .unwrap_or(false);
322
323 if needs_reload {
324 let sd_config = self.sd_config();
325 let loaded = self.base.loaded.as_ref().unwrap();
326 let device = loaded.device.clone();
327 let dtype = loaded.dtype;
328 let _ = loaded;
329
330 self.base.progress.stage_start("Reloading UNet (GPU)");
331 let reload_start = Instant::now();
332 let unet = self.build_unet_for_strategy(&sd_config, &device, dtype)?;
333 self.base.loaded.as_mut().unwrap().unet = Some(unet);
334 self.base
335 .progress
336 .stage_done("Reloading UNet (GPU)", reload_start.elapsed());
337 }
338 Ok(())
339 }
340
341 fn build_unet_for_strategy(
354 &self,
355 sd_config: &stable_diffusion::StableDiffusionConfig,
356 device: &Device,
357 dtype: DType,
358 ) -> Result<stable_diffusion::unet_2d::UNet2DConditionModel> {
359 let has_lora = !self.pending_loras.is_empty();
360 if let Some(single_file) = self.single_file_path.as_ref() {
361 let remap = Self::load_sdxl_remap(single_file)?;
362 if has_lora {
363 self.build_unet_single_file_with_lora(single_file, &remap, sd_config, device, dtype)
364 } else {
365 Self::build_unet_single_file(single_file, &remap, sd_config, device, dtype)
366 }
367 } else if has_lora {
368 self.build_unet_diffusers_with_lora(sd_config, device, dtype)
369 } else {
370 Ok(sd_config.build_unet(&self.base.paths.transformer, device, 4, false, dtype)?)
371 }
372 }
373
374 fn build_unet_diffusers_with_lora(
379 &self,
380 sd_config: &stable_diffusion::StableDiffusionConfig,
381 device: &Device,
382 dtype: DType,
383 ) -> Result<stable_diffusion::unet_2d::UNet2DConditionModel> {
384 use candle_core::safetensors::MmapedSafetensors;
385 use candle_nn::VarBuilder;
386
387 let st = unsafe { MmapedSafetensors::multi(&[&self.base.paths.transformer])? };
388
389 struct MmapBackend {
390 st: MmapedSafetensors,
391 }
392 impl candle_nn::var_builder::SimpleBackend for MmapBackend {
393 fn get(
394 &self,
395 _s: candle_core::Shape,
396 name: &str,
397 _h: candle_nn::Init,
398 dtype: DType,
399 dev: &Device,
400 ) -> candle_core::Result<Tensor> {
401 let t = self.st.load(name, dev)?;
402 if t.dtype() != dtype {
403 t.to_dtype(dtype)
404 } else {
405 Ok(t)
406 }
407 }
408 fn get_unchecked(
409 &self,
410 name: &str,
411 dtype: DType,
412 dev: &Device,
413 ) -> candle_core::Result<Tensor> {
414 let t = self.st.load(name, dev)?;
415 if t.dtype() != dtype {
416 t.to_dtype(dtype)
417 } else {
418 Ok(t)
419 }
420 }
421 fn contains_tensor(&self, name: &str) -> bool {
422 self.st.get(name).is_ok()
423 }
424 }
425 let inner: Box<dyn candle_nn::var_builder::SimpleBackend> = Box::new(MmapBackend { st });
426 let wrapped = self.wrap_with_loras(inner)?;
427 let vb = VarBuilder::from_backend(wrapped, dtype, device.clone());
428 Ok(stable_diffusion::unet_2d::UNet2DConditionModel::new(
429 vb,
430 4,
431 4,
432 false,
433 sd_config.unet().clone(),
434 )?)
435 }
436
437 fn build_unet_single_file_with_lora(
443 &self,
444 single_file: &std::path::Path,
445 remap: &crate::loader::SdxlRemap,
446 sd_config: &stable_diffusion::StableDiffusionConfig,
447 device: &Device,
448 dtype: DType,
449 ) -> Result<stable_diffusion::unet_2d::UNet2DConditionModel> {
450 use crate::loader::SingleFileBackend;
451 use candle_nn::VarBuilder;
452
453 let backend = SingleFileBackend::from_sdxl_unet(single_file, remap)?;
454 let inner: Box<dyn candle_nn::var_builder::SimpleBackend> = Box::new(backend);
455 let wrapped = self.wrap_with_loras(inner)?;
456 let vb = VarBuilder::from_backend(wrapped, dtype, device.clone());
457 Ok(stable_diffusion::unet_2d::UNet2DConditionModel::new(
458 vb,
459 4,
460 4,
461 false,
462 sd_config.unet().clone(),
463 )?)
464 }
465
466 fn wrap_with_loras(
470 &self,
471 inner: Box<dyn candle_nn::var_builder::SimpleBackend>,
472 ) -> Result<Box<dyn candle_nn::var_builder::SimpleBackend>> {
473 let adapters = super::lora::load_lora_adapters(&self.pending_loras, &self.base.progress)?;
474 let specs: Vec<super::lora::SdxlLoraSpec<'_>> = adapters
475 .iter()
476 .zip(self.pending_loras.iter())
477 .map(|(adapter, w)| super::lora::SdxlLoraSpec {
478 adapter: adapter.as_ref(),
479 scale: w.scale,
480 path_hash: super::lora::lora_path_hash(&w.path),
481 })
482 .collect();
483 super::lora::wrap_backend_with_lora(inner, &specs, &self.base.progress, None)
484 }
485
486 fn build_vae_for_strategy(
491 &self,
492 sd_config: &stable_diffusion::StableDiffusionConfig,
493 device: &Device,
494 dtype: DType,
495 ) -> Result<stable_diffusion::vae::AutoEncoderKL> {
496 if let Some(single_file) = self.single_file_path.as_ref() {
497 let remap = Self::load_sdxl_remap(single_file)?;
498 Self::build_vae_single_file(single_file, &remap, sd_config, device, dtype)
499 } else {
500 self.build_vae_diffusers(sd_config, device, dtype)
501 }
502 }
503
/// Test helper: fetches CPU tensors for this engine's configured VAE file from the
/// shared pool (`Ok(None)` when there is no pool).
#[cfg(test)]
fn load_vae_cpu_tensors(&self) -> Result<Option<Arc<HashMap<String, Tensor>>>> {
    let vae_path = self.base.paths.vae.as_path();
    self.load_vae_cpu_tensors_for_path(vae_path)
}
508
509 fn load_vae_cpu_tensors_for_path(
510 &self,
511 vae_path: &Path,
512 ) -> Result<Option<Arc<HashMap<String, Tensor>>>> {
513 let Some(shared_pool) = &self.shared_pool else {
514 return Ok(None);
515 };
516 shared_pool
517 .lock()
518 .unwrap()
519 .load_safetensors_cpu_tensors(std::slice::from_ref(&vae_path))
520 }
521
522 fn load_vae_var_builder<'a>(
523 &self,
524 vae_path: &Path,
525 dtype: DType,
526 device: &Device,
527 component: &str,
528 ) -> Result<candle_nn::VarBuilder<'a>> {
529 if let Some(tensors) = self.load_vae_cpu_tensors_for_path(vae_path)? {
530 return Ok(crate::encoders::park::varbuilder_from_parked(
531 tensors.as_ref(),
532 dtype,
533 device,
534 ));
535 }
536
537 crate::weight_loader::load_safetensors_with_progress(
538 std::slice::from_ref(&vae_path),
539 dtype,
540 device,
541 component,
542 &self.base.progress,
543 )
544 }
545
546 fn build_vae_diffusers(
547 &self,
548 sd_config: &stable_diffusion::StableDiffusionConfig,
549 device: &Device,
550 dtype: DType,
551 ) -> Result<stable_diffusion::vae::AutoEncoderKL> {
552 let vb = self.load_vae_var_builder(&self.base.paths.vae, dtype, device, "VAE")?;
553 Ok(stable_diffusion::vae::AutoEncoderKL::new(
554 vb,
555 3,
556 3,
557 sd_config.autoencoder().clone(),
558 )?)
559 }
560
561 pub fn load(&mut self) -> Result<()> {
574 if self.base.loaded.is_some() {
575 return Ok(());
576 }
577
578 if self.base.load_strategy == LoadStrategy::Sequential {
580 return Ok(());
581 }
582
583 let (clip_encoder, clip_tokenizer, clip_encoder_2, clip_tokenizer_2) =
584 self.validate_paths()?;
585
586 tracing::info!(model = %self.base.model_name, "loading SDXL model components...");
587
588 let device = crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)?;
589 let dtype = if crate::device::is_gpu(&device) {
590 DType::F16
591 } else {
592 DType::F32
593 };
594
595 let sd_config = self.sd_config();
596
597 let tier1 = self
599 .pending_placement
600 .as_ref()
601 .map(|p| p.text_encoders)
602 .unwrap_or_default();
603 let clip_device = crate::device::resolve_device(Some(tier1), || Ok(device.clone()))?;
604
605 let vae_dtype = resolve_sdxl_vae_dtype(dtype, self.single_file_path.is_some());
607 let (unet, vae, clip_l, clip_g) = if let Some(single_file) = self.single_file_path.clone() {
608 self.load_components_single_file(
609 &single_file,
610 &sd_config,
611 &device,
612 &clip_device,
613 dtype,
614 vae_dtype,
615 )?
616 } else {
617 self.load_components_diffusers(
618 &clip_encoder,
619 &clip_encoder_2,
620 &sd_config,
621 &device,
622 &clip_device,
623 dtype,
624 vae_dtype,
625 )?
626 };
627
628 let tokenizer_l = self.load_clip_tokenizer(&clip_tokenizer, "CLIP-L")?;
629 let tokenizer_g = self.load_clip_tokenizer(&clip_tokenizer_2, "CLIP-G")?;
630
631 self.base.loaded = Some(LoadedSDXL {
632 unet: Some(unet),
633 vae,
634 clip_l,
635 clip_g,
636 tokenizer_l,
637 tokenizer_g,
638 sd_config,
639 device,
640 clip_device,
641 dtype,
642 vae_dtype,
643 });
644
645 tracing::info!(model = %self.base.model_name, "all SDXL components loaded successfully");
646 Ok(())
647 }
648
649 #[allow(clippy::too_many_arguments)]
651 fn load_components_diffusers(
652 &mut self,
653 clip_encoder: &std::path::Path,
654 clip_encoder_2: &std::path::Path,
655 sd_config: &stable_diffusion::StableDiffusionConfig,
656 device: &Device,
657 clip_device: &Device,
658 dtype: DType,
659 vae_dtype: DType,
660 ) -> Result<(
661 stable_diffusion::unet_2d::UNet2DConditionModel,
662 stable_diffusion::vae::AutoEncoderKL,
663 stable_diffusion::clip::ClipTextTransformer,
664 stable_diffusion::clip::ClipTextTransformer,
665 )> {
666 self.base.progress.stage_start("Loading UNet (GPU)");
667 let unet_start = Instant::now();
668 let unet = sd_config.build_unet(&self.base.paths.transformer, device, 4, false, dtype)?;
669 self.base
670 .progress
671 .stage_done("Loading UNet (GPU)", unet_start.elapsed());
672
673 self.base.progress.stage_start("Loading VAE (GPU)");
674 let vae_start = Instant::now();
675 let vae = self.build_vae_diffusers(sd_config, device, vae_dtype)?;
676 self.base
677 .progress
678 .stage_done("Loading VAE (GPU)", vae_start.elapsed());
679
680 self.base.progress.stage_start("Loading CLIP-L encoder");
681 let clip_l_start = Instant::now();
682 let clip_l = stable_diffusion::build_clip_transformer(
683 &sd_config.clip,
684 clip_encoder,
685 clip_device,
686 DType::F32,
687 )?;
688 self.base
689 .progress
690 .stage_done("Loading CLIP-L encoder", clip_l_start.elapsed());
691
692 self.base.progress.stage_start("Loading CLIP-G encoder");
693 let clip_g_start = Instant::now();
694 let clip2_config = sd_config
695 .clip2
696 .as_ref()
697 .ok_or_else(|| anyhow::anyhow!("SDXL config missing clip2 configuration"))?;
698 let clip_g = stable_diffusion::build_clip_transformer(
699 clip2_config,
700 clip_encoder_2,
701 clip_device,
702 DType::F32,
703 )?;
704 self.base
705 .progress
706 .stage_done("Loading CLIP-G encoder", clip_g_start.elapsed());
707
708 Ok((unet, vae, clip_l, clip_g))
709 }
710
711 fn load_components_single_file(
723 &mut self,
724 single_file: &std::path::Path,
725 sd_config: &stable_diffusion::StableDiffusionConfig,
726 device: &Device,
727 clip_device: &Device,
728 dtype: DType,
729 vae_dtype: DType,
730 ) -> Result<(
731 stable_diffusion::unet_2d::UNet2DConditionModel,
732 stable_diffusion::vae::AutoEncoderKL,
733 stable_diffusion::clip::ClipTextTransformer,
734 stable_diffusion::clip::ClipTextTransformer,
735 )> {
736 let remap = Self::load_sdxl_remap(single_file)?;
737
738 self.base.progress.stage_start("Loading UNet (single-file)");
739 let unet_start = Instant::now();
740 let unet = Self::build_unet_single_file(single_file, &remap, sd_config, device, dtype)?;
741 self.base
742 .progress
743 .stage_done("Loading UNet (single-file)", unet_start.elapsed());
744
745 self.base.progress.stage_start("Loading VAE (single-file)");
746 let vae_start = Instant::now();
747 let vae = Self::build_vae_single_file(single_file, &remap, sd_config, device, vae_dtype)?;
748 self.base
749 .progress
750 .stage_done("Loading VAE (single-file)", vae_start.elapsed());
751
752 self.base
753 .progress
754 .stage_start("Loading CLIP-L (single-file)");
755 let clip_l_start = Instant::now();
756 let clip_l =
757 Self::build_clip_l_single_file(single_file, &remap, &sd_config.clip, clip_device)?;
758 self.base
759 .progress
760 .stage_done("Loading CLIP-L (single-file)", clip_l_start.elapsed());
761
762 self.base
763 .progress
764 .stage_start("Loading CLIP-G (single-file)");
765 let clip_g_start = Instant::now();
766 let clip2_config = sd_config
767 .clip2
768 .as_ref()
769 .ok_or_else(|| anyhow::anyhow!("SDXL config missing clip2 configuration"))?;
770 let clip_g =
771 Self::build_clip_g_single_file(single_file, &remap, clip2_config, clip_device)?;
772 self.base
773 .progress
774 .stage_done("Loading CLIP-G (single-file)", clip_g_start.elapsed());
775
776 Ok((unet, vae, clip_l, clip_g))
777 }
778
779 fn load_sdxl_remap(single_file: &std::path::Path) -> Result<crate::loader::SdxlRemap> {
783 use crate::loader::{build_sdxl_remap, single_file as single_file_loader};
784 use mold_catalog::families::Family;
785 let bundle = single_file_loader::load(single_file, Family::Sdxl)
786 .map_err(|e| anyhow::anyhow!("partition single-file SDXL checkpoint: {e}"))?;
787 build_sdxl_remap(&bundle)
788 .map_err(|e| anyhow::anyhow!("build SDXL diffusers→A1111 remap: {e}"))
789 }
790
791 fn build_unet_single_file(
796 single_file: &std::path::Path,
797 remap: &crate::loader::SdxlRemap,
798 sd_config: &stable_diffusion::StableDiffusionConfig,
799 device: &Device,
800 dtype: DType,
801 ) -> Result<stable_diffusion::unet_2d::UNet2DConditionModel> {
802 use crate::loader::SingleFileBackend;
803 use candle_nn::VarBuilder;
804 let backend = SingleFileBackend::from_sdxl_unet(single_file, remap)?;
805 let vb = VarBuilder::from_backend(Box::new(backend), dtype, device.clone());
806 Ok(stable_diffusion::unet_2d::UNet2DConditionModel::new(
807 vb,
808 4,
809 4,
810 false,
811 sd_config.unet().clone(),
812 )?)
813 }
814
815 fn build_vae_single_file(
818 single_file: &std::path::Path,
819 remap: &crate::loader::SdxlRemap,
820 sd_config: &stable_diffusion::StableDiffusionConfig,
821 device: &Device,
822 dtype: DType,
823 ) -> Result<stable_diffusion::vae::AutoEncoderKL> {
824 use crate::loader::SingleFileBackend;
825 use candle_nn::VarBuilder;
826 let backend = SingleFileBackend::from_sdxl_vae(single_file, remap)?;
827 let vb = VarBuilder::from_backend(Box::new(backend), dtype, device.clone());
828 Ok(stable_diffusion::vae::AutoEncoderKL::new(
829 vb,
830 3,
831 3,
832 sd_config.autoencoder().clone(),
833 )?)
834 }
835
836 fn build_clip_l_single_file(
847 single_file: &std::path::Path,
848 remap: &crate::loader::SdxlRemap,
849 clip_config: &stable_diffusion::clip::Config,
850 clip_device: &Device,
851 ) -> Result<stable_diffusion::clip::ClipTextTransformer> {
852 use crate::loader::SingleFileBackend;
853 use candle_nn::VarBuilder;
854 let backend = SingleFileBackend::from_sdxl_clip_l(single_file, remap)?;
855 let vb = VarBuilder::from_backend(Box::new(backend), DType::F32, clip_device.clone());
856 Ok(stable_diffusion::clip::ClipTextTransformer::new(
857 vb,
858 clip_config,
859 )?)
860 }
861
862 fn build_clip_g_single_file(
866 single_file: &std::path::Path,
867 remap: &crate::loader::SdxlRemap,
868 clip_config: &stable_diffusion::clip::Config,
869 clip_device: &Device,
870 ) -> Result<stable_diffusion::clip::ClipTextTransformer> {
871 use crate::loader::SingleFileBackend;
872 use candle_nn::VarBuilder;
873 let backend = SingleFileBackend::from_sdxl_clip_g(single_file, remap)?;
874 let vb = VarBuilder::from_backend(Box::new(backend), DType::F32, clip_device.clone());
875 Ok(stable_diffusion::clip::ClipTextTransformer::new(
876 vb,
877 clip_config,
878 )?)
879 }
880
881 fn tokenize(
883 tokenizer: &tokenizers::Tokenizer,
884 prompt: &str,
885 max_len: usize,
886 device: &Device,
887 ) -> Result<Tensor> {
888 let encoding = tokenizer
889 .encode(prompt, true)
890 .map_err(|e| anyhow::anyhow!("tokenization failed: {e}"))?;
891 let mut ids = encoding.get_ids().to_vec();
892 ids.truncate(max_len);
893 while ids.len() < max_len {
895 ids.push(0);
896 }
897 let ids = ids.into_iter().map(|i| i as i64).collect::<Vec<_>>();
898 Ok(Tensor::new(ids, device)?.unsqueeze(0)?)
899 }
900
901 #[allow(clippy::too_many_arguments)]
905 #[allow(clippy::too_many_arguments)]
906 fn denoise_loop(
907 &self,
908 unet: &stable_diffusion::unet_2d::UNet2DConditionModel,
909 text_embeddings: &Tensor,
910 sched: Scheduler,
911 latents: &mut Tensor,
912 guidance: f64,
913 cfg_plus: bool,
914 steps: u32,
915 start_step: usize,
916 inpaint_ctx: Option<&crate::img_utils::InpaintContext>,
917 ) -> Result<()> {
918 let use_cfg = cfg_active(guidance);
919 let mut scheduler = crate::scheduler::build_scheduler(
920 sched,
921 steps as usize,
922 PredictionType::Epsilon,
923 self.is_turbo,
924 )?;
925 let timesteps = scheduler.timesteps().to_vec();
926 let active_timesteps = ×teps[start_step..];
927
928 let cfg_plus_schedule = if cfg_plus && use_cfg && matches!(sched, Scheduler::Ddim) {
933 Some(DdimAlphaSchedule::from_default(steps as usize))
934 } else {
935 if cfg_plus && !use_cfg {
936 tracing::warn!(
937 guidance,
938 "cfg_plus requested but cfg_scale ≈ 1.0 — falling back to standard step (no uncond available)"
939 );
940 } else if cfg_plus {
941 tracing::warn!(
942 scheduler = ?sched,
943 "cfg_plus requested but only DDIM is supported on SDXL/SD1.5 — falling back to standard step. Re-run with `--scheduler ddim` to enable CFG++."
944 );
945 }
946 None
947 };
948
949 let denoise_label = format!("Denoising ({} steps)", active_timesteps.len());
950 self.base.progress.stage_start(&denoise_label);
951 let denoise_start = Instant::now();
952
953 for (step_idx, &t) in active_timesteps.iter().enumerate() {
954 let step_start = std::time::Instant::now();
955 let latent_input = if use_cfg {
956 Tensor::cat(&[&*latents, &*latents], 0)?
957 } else {
958 latents.clone()
959 };
960
961 let latent_input = scheduler.scale_model_input(latent_input, t)?;
962 let noise_pred = unet.forward(&latent_input, t as f64, text_embeddings)?;
963
964 let (noise_pred_blended, noise_pred_uncond_opt) = if use_cfg {
968 let chunks = noise_pred.chunk(2, 0)?;
969 let noise_pred_uncond = chunks[0].clone();
970 let noise_pred_cond = &chunks[1];
971 let blended =
972 (&noise_pred_uncond + ((noise_pred_cond - &noise_pred_uncond)? * guidance)?)?;
973 (blended, Some(noise_pred_uncond))
974 } else {
975 (noise_pred, None)
976 };
977
978 *latents = match (cfg_plus_schedule.as_ref(), noise_pred_uncond_opt.as_ref()) {
979 (Some(ddim_sched), Some(eps_uncond)) => {
980 ddim_sched.cfg_plus_step(&*latents, &noise_pred_blended, eps_uncond, t)?
981 }
982 _ => scheduler.step(&noise_pred_blended, t, &*latents)?,
983 };
984
985 if let Some(ctx) = inpaint_ctx {
986 let noised_original =
987 scheduler.add_noise(&ctx.original_latents, ctx.noise.clone(), t)?;
988 *latents = crate::img2img::blend_inpaint_latents(&*latents, ctx, &noised_original)?;
989 }
990
991 self.base.progress.emit(ProgressEvent::DenoiseStep {
992 step: step_idx + 1,
993 total: active_timesteps.len(),
994 elapsed: step_start.elapsed(),
995 });
996 }
997
998 self.base
999 .progress
1000 .stage_done(&denoise_label, denoise_start.elapsed());
1001 Ok(())
1002 }
1003
1004 #[allow(clippy::too_many_arguments)]
1012 fn prepare_img2img_latents(
1013 &self,
1014 vae: &stable_diffusion::vae::AutoEncoderKL,
1015 source_bytes: &[u8],
1016 width: u32,
1017 height: u32,
1018 strength: f64,
1019 steps: u32,
1020 sched: Scheduler,
1021 seed: u64,
1022 device: &Device,
1023 dtype: DType,
1024 vae_dtype: DType,
1025 ) -> Result<(Tensor, usize, Tensor, Tensor)> {
1026 use crate::img_utils::{decode_source_image, NormalizeRange};
1027 let vae_scale = if self.is_turbo {
1028 VAE_SCALE_TURBO
1029 } else {
1030 VAE_SCALE_STANDARD
1031 };
1032 let cache_key = image_size_cache_key(source_bytes, width, height);
1033 let (encoded, cache_hit) = get_or_insert_cached_tensor(
1034 &self.source_latent_cache,
1035 cache_key,
1036 device,
1037 dtype,
1038 || {
1039 self.base
1040 .progress
1041 .stage_start("Encoding source image (VAE)");
1042 let encode_start = Instant::now();
1043
1044 let source_tensor = decode_source_image(
1045 source_bytes,
1046 width,
1047 height,
1048 NormalizeRange::MinusOneToOne,
1049 device,
1050 vae_dtype,
1051 )?;
1052 let encoded = vae.encode(&source_tensor)?;
1053 let encoded = (encoded.mode()? * vae_scale)?;
1054 let encoded = encoded.to_dtype(dtype)?;
1058
1059 self.base
1060 .progress
1061 .stage_done("Encoding source image (VAE)", encode_start.elapsed());
1062 Ok(encoded)
1063 },
1064 )?;
1065 if cache_hit {
1066 self.base.progress.cache_hit("source image latents");
1067 }
1068
1069 let start_step = crate::img2img::img2img_start_index(steps as usize, strength);
1070
1071 let scheduler = crate::scheduler::build_scheduler(
1072 sched,
1073 steps as usize,
1074 PredictionType::Epsilon,
1075 self.is_turbo,
1076 )?;
1077 let timesteps = scheduler.timesteps().to_vec();
1078
1079 let latent_h = height as usize / 8;
1080 let latent_w = width as usize / 8;
1081 let noise =
1082 crate::engine::seeded_randn(seed, &[1, 4, latent_h, latent_w], device, DType::F32)?;
1083 let noise = noise.to_dtype(dtype)?;
1084
1085 let noised = if start_step < timesteps.len() {
1086 scheduler.add_noise(&encoded, noise.clone(), timesteps[start_step])?
1087 } else {
1088 encoded.clone()
1089 };
1090
1091 tracing::info!(
1092 start_step,
1093 total_steps = steps,
1094 strength,
1095 "img2img: starting from step {start_step}"
1096 );
1097
1098 Ok((noised, start_step, encoded, noise))
1099 }
1100
1101 #[allow(clippy::too_many_arguments)]
1103 fn encode_prompt(
1104 &self,
1105 clip_l: &stable_diffusion::clip::ClipTextTransformer,
1106 clip_g: &stable_diffusion::clip::ClipTextTransformer,
1107 tokenizer_l: &tokenizers::Tokenizer,
1108 tokenizer_g: &tokenizers::Tokenizer,
1109 prompt: &str,
1110 negative_prompt: &str,
1111 max_len: usize,
1112 device: &Device,
1113 clip_device: &Device,
1114 dtype: DType,
1115 guidance: f64,
1116 ) -> Result<Tensor> {
1117 let cache_key = cfg_prompt_cache_key(prompt, negative_prompt, guidance);
1123 let (text_embeddings, cache_hit) =
1124 get_or_insert_cached_tensor(&self.prompt_cache, cache_key, device, dtype, || {
1125 let use_cfg = cfg_active(guidance);
1126
1127 self.base.progress.stage_start("Encoding prompt (CLIP-L)");
1128 let encode_l_start = Instant::now();
1129 let tokens_l = Self::tokenize(tokenizer_l, prompt, max_len, clip_device)?;
1130 let text_emb_l = clip_l.forward(&tokens_l)?;
1131 self.base
1132 .progress
1133 .stage_done("Encoding prompt (CLIP-L)", encode_l_start.elapsed());
1134
1135 self.base.progress.stage_start("Encoding prompt (CLIP-G)");
1136 let encode_g_start = Instant::now();
1137 let tokens_g = Self::tokenize(tokenizer_g, prompt, max_len, clip_device)?;
1138 let text_emb_g = clip_g.forward(&tokens_g)?;
1139 self.base
1140 .progress
1141 .stage_done("Encoding prompt (CLIP-G)", encode_g_start.elapsed());
1142
1143 let text_embeddings = Tensor::cat(&[&text_emb_l, &text_emb_g], D::Minus1)?;
1144
1145 let text_embeddings = if use_cfg {
1146 let uncond_tokens_l =
1147 Self::tokenize(tokenizer_l, negative_prompt, max_len, clip_device)?;
1148 let uncond_emb_l = clip_l.forward(&uncond_tokens_l)?;
1149 let uncond_tokens_g =
1150 Self::tokenize(tokenizer_g, negative_prompt, max_len, clip_device)?;
1151 let uncond_emb_g = clip_g.forward(&uncond_tokens_g)?;
1152 let uncond_embeddings =
1153 Tensor::cat(&[&uncond_emb_l, &uncond_emb_g], D::Minus1)?;
1154 Tensor::cat(&[&uncond_embeddings, &text_embeddings], 0)?
1155 } else {
1156 text_embeddings
1157 };
1158
1159 let text_embeddings = text_embeddings.to_device(device)?;
1160 Ok(text_embeddings.to_dtype(dtype)?)
1161 })?;
1162 if cache_hit {
1163 self.base.progress.cache_hit("prompt conditioning");
1164 return Ok(text_embeddings);
1165 }
1166 Ok(text_embeddings)
1167 }
1168
1169 fn cached_mask(
1170 &self,
1171 mask_bytes: &[u8],
1172 latent_h: usize,
1173 latent_w: usize,
1174 device: &Device,
1175 dtype: DType,
1176 ) -> Result<Tensor> {
1177 let key = latent_size_cache_key(mask_bytes, latent_h, latent_w);
1178 let (mask, cache_hit) =
1179 get_or_insert_cached_tensor(&self.mask_cache, key, device, dtype, || {
1180 crate::img_utils::decode_mask_image(mask_bytes, latent_h, latent_w, device, dtype)
1181 })?;
1182 if cache_hit {
1183 self.base.progress.cache_hit("inpaint mask");
1184 return Ok(mask);
1185 }
1186 Ok(mask)
1187 }
1188
1189 fn generate_sequential(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1197 let (clip_encoder, clip_tokenizer, clip_encoder_2, clip_tokenizer_2) =
1198 self.validate_paths()?;
1199
1200 if let Some(warning) = check_memory_budget(&self.base.paths, LoadStrategy::Sequential) {
1202 self.base.progress.info(&warning);
1203 }
1204
1205 let device = crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)?;
1206 let dtype = if crate::device::is_gpu(&device) {
1207 DType::F16
1208 } else {
1209 DType::F32
1210 };
1211
1212 let sd_config = self.sd_config();
1213 let max_len = sd_config.clip.max_position_embeddings;
1214
1215 let start = Instant::now();
1216 let seed = req.seed.unwrap_or_else(rand_seed);
1217
1218 let width = req.width as usize;
1219 let height = req.height as usize;
1220 let guidance = req.guidance;
1221
1222 tracing::info!(
1223 prompt = %req.prompt,
1224 seed, width, height,
1225 steps = req.steps,
1226 guidance,
1227 "starting sequential SDXL generation"
1228 );
1229
1230 self.base
1231 .progress
1232 .info("Using sequential loading (load-use-drop) to minimize peak memory");
1233
1234 let neg = req.negative_prompt.as_deref().unwrap_or("");
1236 let cache_key = cfg_prompt_cache_key(&req.prompt, neg, guidance);
1237 let text_embeddings = if let Some(tensor) =
1238 restore_cached_tensor(&self.prompt_cache, &cache_key, &device, dtype)?
1239 {
1240 self.base.progress.cache_hit("prompt conditioning");
1241 tensor
1242 } else {
1243 if let Some(status) = memory_status_string() {
1244 self.base.progress.info(&status);
1245 }
1246
1247 let tokenizer_l = self.load_clip_tokenizer(&clip_tokenizer, "CLIP-L")?;
1248 let tokenizer_g = self.load_clip_tokenizer(&clip_tokenizer_2, "CLIP-G")?;
1249
1250 let tier1 = self
1251 .pending_placement
1252 .as_ref()
1253 .map(|p| p.text_encoders)
1254 .unwrap_or_default();
1255 let clip_device = crate::device::resolve_device(Some(tier1), || Ok(device.clone()))?;
1256
1257 let (clip_l, clip_g) =
1263 if let Some(single_file) = self.single_file_path.clone() {
1264 let remap = Self::load_sdxl_remap(&single_file)?;
1265
1266 self.base
1267 .progress
1268 .stage_start("Loading CLIP-L (single-file)");
1269 let clip_l_start = Instant::now();
1270 let clip_l = Self::build_clip_l_single_file(
1271 &single_file,
1272 &remap,
1273 &sd_config.clip,
1274 &clip_device,
1275 )?;
1276 self.base
1277 .progress
1278 .stage_done("Loading CLIP-L (single-file)", clip_l_start.elapsed());
1279
1280 self.base
1281 .progress
1282 .stage_start("Loading CLIP-G (single-file)");
1283 let clip_g_start = Instant::now();
1284 let clip2_config = sd_config.clip2.as_ref().ok_or_else(|| {
1285 anyhow::anyhow!("SDXL config missing clip2 configuration")
1286 })?;
1287 let clip_g = Self::build_clip_g_single_file(
1288 &single_file,
1289 &remap,
1290 clip2_config,
1291 &clip_device,
1292 )?;
1293 self.base
1294 .progress
1295 .stage_done("Loading CLIP-G (single-file)", clip_g_start.elapsed());
1296
1297 (clip_l, clip_g)
1298 } else {
1299 self.base.progress.stage_start("Loading CLIP-L encoder");
1300 let clip_l_start = Instant::now();
1301 let clip_l = stable_diffusion::build_clip_transformer(
1302 &sd_config.clip,
1303 &clip_encoder,
1304 &clip_device,
1305 DType::F32,
1306 )?;
1307 self.base
1308 .progress
1309 .stage_done("Loading CLIP-L encoder", clip_l_start.elapsed());
1310
1311 self.base.progress.stage_start("Loading CLIP-G encoder");
1312 let clip_g_start = Instant::now();
1313 let clip2_config = sd_config.clip2.as_ref().ok_or_else(|| {
1314 anyhow::anyhow!("SDXL config missing clip2 configuration")
1315 })?;
1316 let clip_g = stable_diffusion::build_clip_transformer(
1317 clip2_config,
1318 &clip_encoder_2,
1319 &clip_device,
1320 DType::F32,
1321 )?;
1322 self.base
1323 .progress
1324 .stage_done("Loading CLIP-G encoder", clip_g_start.elapsed());
1325
1326 (clip_l, clip_g)
1327 };
1328
1329 let text_embeddings = self.encode_prompt(
1330 &clip_l,
1331 &clip_g,
1332 &tokenizer_l,
1333 &tokenizer_g,
1334 &req.prompt,
1335 neg,
1336 max_len,
1337 &device,
1338 &clip_device,
1339 dtype,
1340 guidance,
1341 )?;
1342
1343 drop(clip_l);
1344 drop(clip_g);
1345 self.base.progress.info("Freed CLIP-L and CLIP-G encoders");
1346 tracing::info!("CLIP encoders dropped (sequential mode)");
1347
1348 text_embeddings
1349 };
1350
1351 let unet_size = std::fs::metadata(&self.base.paths.transformer)
1353 .map(|m| m.len())
1354 .unwrap_or(0);
1355 let unet_batch = if req.guidance > 1.0 { 2 } else { 1 };
1357 let unet_activation_budget = crate::device::activation_bytes(
1358 req.width,
1359 req.height,
1360 unet_batch,
1361 crate::device::dtype_bytes(dtype),
1362 crate::device::ActivationFamily::SdxlUnet,
1363 );
1364 preflight_memory_check("UNet", unet_size, unet_activation_budget)?;
1365 if let Some(status) = memory_status_string() {
1366 self.base.progress.info(&status);
1367 }
1368
1369 self.base.progress.stage_start("Loading UNet (GPU)");
1370 let unet_start = Instant::now();
1371 let unet = self.build_unet_for_strategy(&sd_config, &device, dtype)?;
1372 self.base
1373 .progress
1374 .stage_done("Loading UNet (GPU)", unet_start.elapsed());
1375
1376 let sched = req.scheduler.unwrap_or(self.scheduler);
1377 let is_img2img = req.source_image.is_some();
1378
1379 let (mut latents, start_step, inpaint_ctx) = if let Some(ref source_bytes) =
1380 req.source_image
1381 {
1382 self.base
1383 .progress
1384 .info("img2img mode: encoding source image before denoising");
1385
1386 self.base.progress.stage_start("Loading VAE (GPU)");
1387 let vae_start_t = Instant::now();
1388 let vae_dtype = resolve_sdxl_vae_dtype(dtype, self.single_file_path.is_some());
1389 let vae = self.build_vae_for_strategy(&sd_config, &device, vae_dtype)?;
1390 self.base
1391 .progress
1392 .stage_done("Loading VAE (GPU)", vae_start_t.elapsed());
1393
1394 let (latents, start_step, encoded, noise) = self.prepare_img2img_latents(
1395 &vae,
1396 source_bytes,
1397 req.width,
1398 req.height,
1399 req.strength,
1400 req.steps,
1401 sched,
1402 seed,
1403 &device,
1404 dtype,
1405 vae_dtype,
1406 )?;
1407
1408 let inpaint_ctx = if let Some(ref mask_bytes) = req.mask_image {
1409 let mask = self.cached_mask(mask_bytes, height / 8, width / 8, &device, dtype)?;
1410 Some(crate::img_utils::InpaintContext {
1411 original_latents: encoded,
1412 mask,
1413 noise,
1414 })
1415 } else {
1416 None
1417 };
1418
1419 drop(vae);
1420 self.base
1421 .progress
1422 .info("Freed VAE (will reload for decode)");
1423 device.synchronize()?;
1424
1425 (latents, start_step, inpaint_ctx)
1426 } else {
1427 let latent_h = height / 8;
1428 let latent_w = width / 8;
1429 let init_scheduler = crate::scheduler::build_scheduler(
1430 sched,
1431 req.steps as usize,
1432 PredictionType::Epsilon,
1433 self.is_turbo,
1434 )?;
1435 let init_noise_sigma = init_scheduler.init_noise_sigma();
1436 drop(init_scheduler);
1437 let latents = (crate::engine::seeded_randn(
1438 seed,
1439 &[1, 4, latent_h, latent_w],
1440 &device,
1441 DType::F32,
1442 )? * init_noise_sigma)?;
1443 (latents.to_dtype(dtype)?, 0, None)
1444 };
1445
1446 self.denoise_loop(
1447 &unet,
1448 &text_embeddings,
1449 sched,
1450 &mut latents,
1451 guidance,
1452 resolve_cfg_plus(req),
1453 req.steps,
1454 start_step,
1455 inpaint_ctx.as_ref(),
1456 )?;
1457
1458 drop(inpaint_ctx);
1459 drop(unet);
1460 drop(text_embeddings);
1461 device.synchronize()?;
1462 self.base.progress.info("Freed UNet");
1463 tracing::info!("UNet dropped (sequential mode)");
1464
1465 let vae_load_label = if is_img2img {
1467 "Reloading VAE (GPU)"
1468 } else {
1469 "Loading VAE (GPU)"
1470 };
1471 self.base.progress.stage_start(vae_load_label);
1472 let vae_start = Instant::now();
1473 let vae_dtype = resolve_sdxl_vae_dtype(dtype, self.single_file_path.is_some());
1474 let vae = self.build_vae_for_strategy(&sd_config, &device, vae_dtype)?;
1475 self.base
1476 .progress
1477 .stage_done(vae_load_label, vae_start.elapsed());
1478
1479 self.base.progress.stage_start("VAE decode");
1480 let vae_decode_start = Instant::now();
1481
1482 let vae_scale = if self.is_turbo {
1483 VAE_SCALE_TURBO
1484 } else {
1485 VAE_SCALE_STANDARD
1486 };
1487 let latents = (latents / vae_scale)?;
1488 let latents_for_vae = latents.to_dtype(vae_dtype)?;
1489 let device_for_sync = device.clone();
1490 let img = crate::vae_tiling::decode_with_oom_fallback(
1491 &latents_for_vae,
1492 |t| vae.decode(t).map_err(Into::into),
1493 || {
1494 if let Err(e) = device_for_sync.synchronize() {
1495 tracing::warn!(
1496 "SDXL (sequential) device.synchronize() after VAE OOM failed: {e}"
1497 );
1498 }
1499 },
1500 )?;
1501
1502 let img = ((img / 2.)? + 0.5)?.clamp(0f32, 1f32)?;
1503 let img = (img * 255.)?.to_dtype(DType::U8)?;
1504 let img = img.squeeze(0)?;
1505
1506 self.base
1507 .progress
1508 .stage_done("VAE decode", vae_decode_start.elapsed());
1509
1510 let output_metadata = build_output_metadata(req, seed, Some(sched));
1512 let image_bytes = encode_image(
1513 &img,
1514 req.resolved_output_format(),
1515 req.width,
1516 req.height,
1517 output_metadata.as_ref(),
1518 )?;
1519
1520 let generation_time_ms = start.elapsed().as_millis() as u64;
1521 tracing::info!(
1522 generation_time_ms,
1523 seed,
1524 "sequential SDXL generation complete"
1525 );
1526
1527 Ok(GenerateResponse {
1528 images: vec![ImageData {
1529 data: image_bytes,
1530 format: req.resolved_output_format(),
1531 width: req.width,
1532 height: req.height,
1533 index: 0,
1534 }],
1535 generation_time_ms,
1536 model: req.model.clone(),
1537 seed_used: seed,
1538 video: None,
1539 gpu: None,
1540 })
1541 }
1542}
1543
1544impl SDXLEngine {
1545 fn generate_inner(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1546 if self.base.load_strategy == LoadStrategy::Sequential {
1548 return self.generate_sequential(req);
1549 }
1550
1551 let requested_stack = lora_stack_fingerprint(&self.pending_loras);
1555 if requested_stack != self.active_lora_fingerprint {
1556 if let Some(loaded) = self.base.loaded.as_mut() {
1557 if loaded.unet.is_some() {
1558 loaded.unet = None;
1559 loaded.device.synchronize()?;
1560 tracing::info!("SDXL UNet dropped (LoRA stack changed)");
1561 }
1562 }
1563 self.active_lora_fingerprint = requested_stack;
1564 }
1565
1566 self.reload_unet_if_needed()?;
1568
1569 let loaded = self
1570 .base
1571 .loaded
1572 .as_ref()
1573 .ok_or_else(|| anyhow::anyhow!("model not loaded — call load() first"))?;
1574
1575 let start = Instant::now();
1576 let seed = req.seed.unwrap_or_else(rand_seed);
1577
1578 let width = req.width as usize;
1579 let height = req.height as usize;
1580 let guidance = req.guidance;
1581
1582 tracing::info!(
1583 prompt = %req.prompt,
1584 seed, width, height,
1585 steps = req.steps,
1586 guidance,
1587 scheduler = %self.scheduler,
1588 "starting SDXL generation"
1589 );
1590
1591 let max_len = loaded.sd_config.clip.max_position_embeddings;
1593 let neg = req.negative_prompt.as_deref().unwrap_or("");
1594 let text_embeddings = self.encode_prompt(
1595 &loaded.clip_l,
1596 &loaded.clip_g,
1597 &loaded.tokenizer_l,
1598 &loaded.tokenizer_g,
1599 &req.prompt,
1600 neg,
1601 max_len,
1602 &loaded.device,
1603 &loaded.clip_device,
1604 loaded.dtype,
1605 guidance,
1606 )?;
1607
1608 let sched = req.scheduler.unwrap_or(self.scheduler);
1610
1611 let (mut latents, start_step, inpaint_ctx) =
1612 if let Some(ref source_bytes) = req.source_image {
1613 let (latents, start_step, encoded, noise) = self.prepare_img2img_latents(
1614 &loaded.vae,
1615 source_bytes,
1616 req.width,
1617 req.height,
1618 req.strength,
1619 req.steps,
1620 sched,
1621 seed,
1622 &loaded.device,
1623 loaded.dtype,
1624 loaded.vae_dtype,
1625 )?;
1626 let inpaint_ctx = if let Some(ref mask_bytes) = req.mask_image {
1627 let mask = self.cached_mask(
1628 mask_bytes,
1629 height / 8,
1630 width / 8,
1631 &loaded.device,
1632 loaded.dtype,
1633 )?;
1634 Some(crate::img_utils::InpaintContext {
1635 original_latents: encoded,
1636 mask,
1637 noise,
1638 })
1639 } else {
1640 None
1641 };
1642 (latents, start_step, inpaint_ctx)
1643 } else {
1644 let latent_h = height / 8;
1645 let latent_w = width / 8;
1646 let init_scheduler = crate::scheduler::build_scheduler(
1647 sched,
1648 req.steps as usize,
1649 PredictionType::Epsilon,
1650 self.is_turbo,
1651 )?;
1652 let init_noise_sigma = init_scheduler.init_noise_sigma();
1653 drop(init_scheduler);
1654 let latents = (crate::engine::seeded_randn(
1655 seed,
1656 &[1, 4, latent_h, latent_w],
1657 &loaded.device,
1658 DType::F32,
1659 )? * init_noise_sigma)?;
1660 (latents.to_dtype(loaded.dtype)?, 0, None)
1661 };
1662
1663 let unet = loaded
1665 .unet
1666 .as_ref()
1667 .ok_or_else(|| anyhow::anyhow!("UNet not loaded"))?;
1668 self.denoise_loop(
1669 unet,
1670 &text_embeddings,
1671 sched,
1672 &mut latents,
1673 guidance,
1674 resolve_cfg_plus(req),
1675 req.steps,
1676 start_step,
1677 inpaint_ctx.as_ref(),
1678 )?;
1679
1680 drop(inpaint_ctx);
1682 let _ = loaded;
1683 let loaded = self.base.loaded.as_mut().unwrap();
1684 loaded.unet = None;
1685 loaded.device.synchronize()?;
1686 tracing::info!("UNet dropped to free VRAM for VAE decode");
1687 let _ = loaded;
1688 let loaded = self.base.loaded.as_ref().unwrap();
1689
1690 self.base.progress.stage_start("VAE decode");
1692 let vae_start = Instant::now();
1693
1694 let vae_scale = if self.is_turbo {
1695 VAE_SCALE_TURBO
1696 } else {
1697 VAE_SCALE_STANDARD
1698 };
1699 let latents = (latents / vae_scale)?;
1700 let latents_for_vae = latents.to_dtype(loaded.vae_dtype)?;
1701 let vae = &loaded.vae;
1702 let device_for_sync = loaded.device.clone();
1703 let img = crate::vae_tiling::decode_with_oom_fallback(
1704 &latents_for_vae,
1705 |t| vae.decode(t).map_err(Into::into),
1706 || {
1707 if let Err(e) = device_for_sync.synchronize() {
1708 tracing::warn!(
1709 "SDXL (parallel) device.synchronize() after VAE OOM failed: {e}"
1710 );
1711 }
1712 },
1713 )?;
1714
1715 let img = ((img / 2.)? + 0.5)?.clamp(0f32, 1f32)?;
1717 let img = (img * 255.)?.to_dtype(DType::U8)?;
1718 let img = img.squeeze(0)?; self.base
1721 .progress
1722 .stage_done("VAE decode", vae_start.elapsed());
1723
1724 let output_metadata = build_output_metadata(req, seed, Some(sched));
1726 let image_bytes = encode_image(
1727 &img,
1728 req.resolved_output_format(),
1729 req.width,
1730 req.height,
1731 output_metadata.as_ref(),
1732 )?;
1733
1734 let generation_time_ms = start.elapsed().as_millis() as u64;
1735 tracing::info!(generation_time_ms, seed, "SDXL generation complete");
1736
1737 Ok(GenerateResponse {
1738 images: vec![ImageData {
1739 data: image_bytes,
1740 format: req.resolved_output_format(),
1741 width: req.width,
1742 height: req.height,
1743 index: 0,
1744 }],
1745 generation_time_ms,
1746 model: req.model.clone(),
1747 seed_used: seed,
1748 video: None,
1749 gpu: None,
1750 })
1751 }
1752}
1753
1754impl InferenceEngine for SDXLEngine {
1755 fn generate(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1756 self.pending_placement = req.placement.clone();
1757 self.pending_loras = super::lora::effective_sdxl_loras(req);
1758 let result = self.generate_inner(req);
1759 self.pending_placement = None;
1760 self.pending_loras.clear();
1761 result
1762 }
1763
1764 fn model_name(&self) -> &str {
1765 self.base.model_name()
1766 }
1767
1768 fn is_loaded(&self) -> bool {
1769 self.base.is_loaded()
1771 }
1772
1773 fn load(&mut self) -> Result<()> {
1774 SDXLEngine::load(self)
1775 }
1776
1777 fn unload(&mut self) {
1778 self.base.unload();
1779 clear_cache(&self.prompt_cache);
1780 clear_cache(&self.source_latent_cache);
1781 clear_cache(&self.mask_cache);
1782 self.active_lora_fingerprint.clear();
1783 }
1784
1785 fn set_on_progress(&mut self, callback: ProgressCallback) {
1786 self.base.set_on_progress(callback);
1787 }
1788
1789 fn clear_on_progress(&mut self) {
1790 self.base.clear_on_progress();
1791 }
1792
1793 fn model_paths(&self) -> Option<&mold_core::ModelPaths> {
1794 Some(&self.base.paths)
1795 }
1796}
1797
#[cfg(test)]
mod tests {
    //! Unit tests for `SDXLEngine`: single-file construction and dispatch,
    //! shared-pool reuse, VAE dtype defaults, CFG predicate boundaries, LoRA
    //! fingerprinting and prompt-cache keying.

    use super::*;
    use crate::engine::InferenceEngine;
    use crate::shared_pool::SharedPool;
    use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
    use std::collections::HashMap;
    use std::sync::{Arc, Mutex};
    use tokenizers::models::bpe::BPE;

    /// Writes a tiny synthetic safetensors file to the temp dir and returns its path.
    ///
    /// Every key below gets a one-element F32 tensor holding `0.0`. The keys
    /// cover the `model.diffusion_model.`, `first_stage_model.`,
    /// `conditioner.embedders.0.` and `conditioner.embedders.1.` prefixes,
    /// but nowhere near the full tensor set of a real checkpoint. The file
    /// name embeds `name`, the process id and a nanosecond timestamp to avoid
    /// collisions between tests. Callers remove the file themselves.
    fn synth_sdxl_single_file(name: &str) -> PathBuf {
        let path = std::env::temp_dir().join(format!(
            "mold-sdxl-from-sf-{}-{}-{}.safetensors",
            name,
            std::process::id(),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos(),
        ));

        let keys: &[&str] = &[
            "model.diffusion_model.input_blocks.0.0.weight",
            "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight",
            "first_stage_model.encoder.down.0.block.0.norm1.weight",
            "first_stage_model.quant_conv.weight",
            "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight",
            "conditioner.embedders.0.transformer.text_model.final_layer_norm.weight",
            "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight",
            "conditioner.embedders.1.model.text_projection",
        ];

        // One backing buffer per key; TensorView borrows them until serialisation.
        let f32_zero = 0.0f32.to_le_bytes().to_vec();
        let buffers: Vec<Vec<u8>> = keys.iter().map(|_| f32_zero.clone()).collect();
        let mut tensors: HashMap<String, TensorView<'_>> = HashMap::new();
        for (key, buf) in keys.iter().zip(buffers.iter()) {
            tensors.insert(
                (*key).to_string(),
                TensorView::new(SafeDtype::F32, vec![1], buf).unwrap(),
            );
        }
        serialize_to_file(&tensors, &None, &path).unwrap();
        path
    }

    /// Construction from a single-file checkpoint records the path and stays lazy.
    #[test]
    fn from_single_file_constructs_for_synthetic_sdxl_checkpoint() {
        let single_file = synth_sdxl_single_file("ok");
        // These tokenizer paths are never created; the constructor is expected
        // not to read them.
        let clip_l_tok = std::env::temp_dir().join("mold-sdxl-clip-l-stub.json");
        let clip_g_tok = std::env::temp_dir().join("mold-sdxl-clip-g-stub.json");

        let engine = SDXLEngine::from_single_file(
            "juggernaut-xl-v9".to_string(),
            single_file.clone(),
            clip_l_tok,
            clip_g_tok,
            Scheduler::default(),
            false,
            LoadStrategy::Eager,
            0,
            None,
        )
        .expect("constructor must accept a valid SDXL single-file layout");

        assert_eq!(engine.model_name(), "juggernaut-xl-v9");
        assert_eq!(
            engine.single_file_path.as_deref(),
            Some(single_file.as_path()),
            "single-file path must be stashed for the future load() branch",
        );
        assert!(
            !engine.is_loaded(),
            "constructor must not eagerly materialise model weights",
        );

        let _ = std::fs::remove_file(single_file);
    }

    /// The engine must hand back the exact `Arc` the shared pool already holds
    /// for each CLIP tokenizer (pointer equality), i.e. reuse, not reload.
    #[test]
    fn sdxl_loads_clip_tokenizers_through_shared_pool() {
        let dir = tempfile::tempdir().unwrap();
        let clip_l_tokenizer = dir.path().join("clip-l-tokenizer.json");
        let clip_g_tokenizer = dir.path().join("clip-g-tokenizer.json");
        tokenizers::Tokenizer::new(BPE::default())
            .save(&clip_l_tokenizer, false)
            .unwrap();
        tokenizers::Tokenizer::new(BPE::default())
            .save(&clip_g_tokenizer, false)
            .unwrap();
        let weights_path = dir.path().join("weights.safetensors");
        std::fs::write(&weights_path, b"stub").unwrap();

        // Pre-populate the pool so the engine's lookups can hit it.
        let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
        let pooled_l = shared_pool
            .lock()
            .unwrap()
            .load_tokenizer(&clip_l_tokenizer)
            .unwrap();
        let pooled_g = shared_pool
            .lock()
            .unwrap()
            .load_tokenizer(&clip_g_tokenizer)
            .unwrap();

        let paths = ModelPaths {
            transformer: weights_path.clone(),
            transformer_shards: Vec::new(),
            vae: weights_path.clone(),
            spatial_upscaler: None,
            temporal_upscaler: None,
            distilled_lora: None,
            t5_encoder: None,
            clip_encoder: Some(weights_path.clone()),
            t5_tokenizer: None,
            clip_tokenizer: Some(clip_l_tokenizer.clone()),
            clip_encoder_2: Some(weights_path),
            clip_tokenizer_2: Some(clip_g_tokenizer.clone()),
            text_encoder_files: Vec::new(),
            text_tokenizer: None,
            decoder: None,
        };
        let engine = SDXLEngine::new(
            "sdxl-test".to_string(),
            paths,
            Scheduler::default(),
            false,
            LoadStrategy::Eager,
            0,
            Some(shared_pool),
        );

        let loaded_l = engine
            .load_clip_tokenizer(&clip_l_tokenizer, "CLIP-L")
            .unwrap();
        let loaded_g = engine
            .load_clip_tokenizer(&clip_g_tokenizer, "CLIP-G")
            .unwrap();

        assert!(Arc::ptr_eq(&pooled_l, &loaded_l));
        assert!(Arc::ptr_eq(&pooled_g, &loaded_g));
    }

    /// The engine must reuse the shared pool's cached CPU tensors for the VAE
    /// file (same `Arc`) rather than reading the file again.
    #[test]
    fn sdxl_loads_vae_tensors_through_shared_pool() {
        let dir = tempfile::tempdir().unwrap();
        let vae_path = dir.path().join("vae.safetensors");
        let weight = 1.0f32.to_le_bytes();
        let mut tensors = HashMap::new();
        tensors.insert(
            "encoder.conv_in.weight".to_string(),
            TensorView::new(SafeDtype::F32, vec![1], &weight).unwrap(),
        );
        serialize_to_file(&tensors, &None, &vae_path).unwrap();

        let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
        let pooled = shared_pool
            .lock()
            .unwrap()
            .load_safetensors_cpu_tensors(std::slice::from_ref(&vae_path))
            .unwrap()
            .unwrap();

        let paths = ModelPaths {
            transformer: dir.path().join("unet.safetensors"),
            transformer_shards: Vec::new(),
            vae: vae_path.clone(),
            spatial_upscaler: None,
            temporal_upscaler: None,
            distilled_lora: None,
            t5_encoder: None,
            clip_encoder: Some(dir.path().join("clip-l.safetensors")),
            t5_tokenizer: None,
            clip_tokenizer: Some(dir.path().join("clip-l-tokenizer.json")),
            clip_encoder_2: Some(dir.path().join("clip-g.safetensors")),
            clip_tokenizer_2: Some(dir.path().join("clip-g-tokenizer.json")),
            text_encoder_files: Vec::new(),
            text_tokenizer: None,
            decoder: None,
        };
        let engine = SDXLEngine::new(
            "sdxl-test".to_string(),
            paths,
            Scheduler::default(),
            false,
            LoadStrategy::Eager,
            0,
            Some(shared_pool),
        );

        let loaded = engine.load_vae_cpu_tensors().unwrap().unwrap();

        assert!(Arc::ptr_eq(&pooled, &loaded));
    }

    /// A non-existent checkpoint path must be rejected by the constructor.
    #[test]
    fn from_single_file_rejects_missing_file() {
        let bogus = std::env::temp_dir().join(format!(
            "mold-sdxl-from-sf-missing-{}-{}.safetensors",
            std::process::id(),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos(),
        ));

        let result = SDXLEngine::from_single_file(
            "missing".to_string(),
            bogus,
            std::env::temp_dir().join("mold-sdxl-clip-l-stub.json"),
            std::env::temp_dir().join("mold-sdxl-clip-g-stub.json"),
            Scheduler::default(),
            false,
            LoadStrategy::Eager,
            0,
            None,
        );

        assert!(
            result.is_err(),
            "constructor must surface a missing-file error before deeper parsing",
        );
    }

    /// `load()` on a single-file engine must take the single-file branch.
    ///
    /// The synthetic checkpoint cannot satisfy the full tensor set, so `load`
    /// must fail. The error text must come from the single-file layer.
    #[test]
    fn load_branches_to_single_file_path_and_invokes_candle_constructors() {
        let single_file = synth_sdxl_single_file("load-branch");
        // Empty, uniquely named tokenizer stub files.
        let make_stub = |label: &str| -> PathBuf {
            let path = std::env::temp_dir().join(format!(
                "mold-sdxl-{}-stub-{}-{}.json",
                label,
                std::process::id(),
                std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap()
                    .as_nanos(),
            ));
            std::fs::write(&path, b"").unwrap();
            path
        };
        let clip_l_tok = make_stub("clip-l");
        let clip_g_tok = make_stub("clip-g");

        let mut engine = SDXLEngine::from_single_file(
            "juggernaut-xl-v9".to_string(),
            single_file.clone(),
            clip_l_tok.clone(),
            clip_g_tok.clone(),
            Scheduler::Ddim,
            false,
            LoadStrategy::Eager,
            0,
            None,
        )
        .expect("constructor");

        // Force CPU for the load attempt.
        // NOTE(review): if `expect_err` panics, MOLD_DEVICE stays set for the
        // rest of the test process; other tests here also mutate env vars.
        std::env::set_var("MOLD_DEVICE", "cpu");
        let err = SDXLEngine::load(&mut engine)
            .expect_err("synthetic checkpoint can't satisfy SDXL's full tensor set");
        std::env::remove_var("MOLD_DEVICE");

        let msg = err.to_string();
        assert!(
            msg.contains("single-file") || msg.contains("rename rule"),
            "expected error from the single-file load layer, got: {msg}",
        );

        let _ = std::fs::remove_file(single_file);
        let _ = std::fs::remove_file(clip_l_tok);
        let _ = std::fs::remove_file(clip_g_tok);
    }

    /// Placeholder: ignored by default and currently has an empty body.
    #[test]
    #[ignore]
    fn from_single_file_real_shape_load_smoke() {
    }

    /// CLIP-L from a single file must go through the single-file backend.
    ///
    /// The failure must not come from the diffusers loader, whose error
    /// mentions "cannot find tensor text_model".
    #[test]
    fn build_clip_l_single_file_dispatches_through_backend_not_diffusers_loader() {
        let single_file = synth_sdxl_single_file("seq-clip-l-dispatch");
        let remap = SDXLEngine::load_sdxl_remap(&single_file).expect("remap");

        let result = SDXLEngine::build_clip_l_single_file(
            &single_file,
            &remap,
            &stable_diffusion::clip::Config::sdxl(),
            &Device::Cpu,
        );

        let err = result.expect_err(
            "synthetic CLIP-L is missing token_embedding / position_embedding / \
             every encoder layer beyond layer 0 — construction must fail",
        );
        let msg = err.to_string();
        assert!(
            !msg.contains("cannot find tensor text_model"),
            "expected failure from the SingleFileBackend layer (e.g. 'no rename rule \
             for diffusers key text_model.embeddings.token_embedding.weight'); got the \
             diffusers `from_mmaped_safetensors` error instead — sequential dispatch \
             is still routing through `build_clip_transformer`. Got: {msg}",
        );

        let _ = std::fs::remove_file(single_file);
    }

    /// Same dispatch check as the CLIP-L test, for CLIP-G.
    #[test]
    fn build_clip_g_single_file_dispatches_through_backend_not_diffusers_loader() {
        let single_file = synth_sdxl_single_file("seq-clip-g-dispatch");
        let remap = SDXLEngine::load_sdxl_remap(&single_file).expect("remap");

        let result = SDXLEngine::build_clip_g_single_file(
            &single_file,
            &remap,
            &stable_diffusion::clip::Config::sdxl2(),
            &Device::Cpu,
        );

        let err = result.expect_err("synthetic CLIP-G is incomplete");
        let msg = err.to_string();
        assert!(
            !msg.contains("cannot find tensor text_model"),
            "expected failure from SingleFileBackend, not diffusers loader. Got: {msg}",
        );

        let _ = std::fs::remove_file(single_file);
    }

    /// Same dispatch check for the UNet.
    ///
    /// The diffusers loader's error would mention "cannot find tensor conv_in".
    #[test]
    fn build_unet_single_file_dispatches_through_backend_not_diffusers_loader() {
        let single_file = synth_sdxl_single_file("seq-unet-dispatch");
        let remap = SDXLEngine::load_sdxl_remap(&single_file).expect("remap");

        let result = SDXLEngine::build_unet_single_file(
            &single_file,
            &remap,
            &stable_diffusion::StableDiffusionConfig::sdxl(None, None, None),
            &Device::Cpu,
            DType::F32,
        );

        let err = result.expect_err("synthetic UNet is incomplete");
        let msg = err.to_string();
        assert!(
            !msg.contains("cannot find tensor conv_in"),
            "expected failure from SingleFileBackend, not diffusers loader. Got: {msg}",
        );

        let _ = std::fs::remove_file(single_file);
    }

    /// `is_turbo = true` passed to `from_single_file` must reach the engine field.
    #[test]
    fn from_single_file_threads_is_turbo_true() {
        let single_file = synth_sdxl_single_file("turbo");
        let clip_l_tok = std::env::temp_dir().join("mold-sdxl-turbo-clip-l-stub.json");
        let clip_g_tok = std::env::temp_dir().join("mold-sdxl-turbo-clip-g-stub.json");

        let engine = SDXLEngine::from_single_file(
            "sdxl-turbo:fp16".to_string(),
            single_file.clone(),
            clip_l_tok,
            clip_g_tok,
            Scheduler::EulerAncestral,
            true,
            LoadStrategy::Eager,
            0,
            None,
        )
        .expect("constructor must accept is_turbo = true");

        assert!(
            engine.is_turbo,
            "is_turbo arg must thread into the engine field — sdxl_config() reads this for VAE_SCALE_TURBO",
        );

        let _ = std::fs::remove_file(single_file);
    }

    /// `is_turbo = false` must likewise be preserved.
    #[test]
    fn from_single_file_threads_is_turbo_false() {
        let single_file = synth_sdxl_single_file("standard");
        let clip_l_tok = std::env::temp_dir().join("mold-sdxl-std-clip-l-stub.json");
        let clip_g_tok = std::env::temp_dir().join("mold-sdxl-std-clip-g-stub.json");

        let engine = SDXLEngine::from_single_file(
            "sdxl-base:fp16".to_string(),
            single_file.clone(),
            clip_l_tok,
            clip_g_tok,
            Scheduler::Ddim,
            false,
            LoadStrategy::Eager,
            0,
            None,
        )
        .expect("constructor must accept is_turbo = false");

        assert!(
            !engine.is_turbo,
            "is_turbo = false must produce a standard-config engine",
        );

        let _ = std::fs::remove_file(single_file);
    }

    /// With `MOLD_VAE_DTYPE` unset, the VAE dtype depends on the layout.
    ///
    /// Single-file checkpoints force an F32 VAE even when the compute dtype
    /// is F16. The diffusers layout keeps the compute dtype.
    #[test]
    fn single_file_sdxl_vae_defaults_to_f32_to_avoid_black_finetune_decodes() {
        unsafe { std::env::remove_var("MOLD_VAE_DTYPE") };
        assert_eq!(resolve_sdxl_vae_dtype(DType::F16, true), DType::F32);
        assert_eq!(resolve_sdxl_vae_dtype(DType::F16, false), DType::F16);
    }

    // `cfg_active` boundaries: off at and just below 1.0, on above it.

    #[test]
    fn test_cfg_disabled_at_guidance_1_0() {
        assert!(!cfg_active(1.0));
    }

    #[test]
    fn test_cfg_disabled_just_below_1_0() {
        assert!(!cfg_active(1.0 - 1e-5));
    }

    #[test]
    fn test_cfg_enabled_at_guidance_1_5() {
        assert!(cfg_active(1.5));
    }

    #[test]
    fn test_cfg_enabled_at_guidance_7_5() {
        assert!(cfg_active(7.5));
    }

    /// Properties of the LoRA stack fingerprint that drive UNet invalidation.
    ///
    /// - Equal (path, scale) sequences give equal fingerprints.
    /// - A scale change gives a different fingerprint.
    /// - Reordering the stack gives a different fingerprint.
    #[test]
    fn lora_stack_fingerprint_equality_drives_unet_drop() {
        let a = mold_core::LoraWeight {
            path: "/x.safetensors".into(),
            scale: 0.8,
        };
        let b = mold_core::LoraWeight {
            path: "/y.safetensors".into(),
            scale: 0.4,
        };
        let same_a = mold_core::LoraWeight {
            path: "/x.safetensors".into(),
            scale: 0.8,
        };
        // Structurally identical stacks must not trigger a UNet rebuild.
        assert_eq!(
            lora_stack_fingerprint(&[a.clone(), b.clone()]),
            lora_stack_fingerprint(&[same_a.clone(), b.clone()])
        );
        // Same file, different scale → different fingerprint.
        let scaled = mold_core::LoraWeight {
            path: "/x.safetensors".into(),
            scale: 0.9,
        };
        assert_ne!(
            lora_stack_fingerprint(std::slice::from_ref(&a)),
            lora_stack_fingerprint(std::slice::from_ref(&scaled))
        );
        // Order is part of the fingerprint.
        assert_ne!(
            lora_stack_fingerprint(&[a.clone(), b.clone()]),
            lora_stack_fingerprint(&[b, a])
        );
    }

    /// Prompt-cache keys include the negative prompt.
    ///
    /// A changed negative prompt must miss the cache. The identical
    /// (prompt, negative, guidance) triple must hit.
    #[test]
    fn sdxl_prompt_cache_distinguishes_negative_prompt_changes() {
        use crate::cache::{cfg_prompt_cache_key, store_cached_tensor};

        let cache: Mutex<LruCache<CfgPromptCacheKey, CachedTensor>> =
            Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY));
        let device = Device::Cpu;
        let dtype = DType::F32;
        let embeddings = candle_core::Tensor::zeros((1, 4), dtype, &device).unwrap();

        let key_a = cfg_prompt_cache_key("a cat", "blurry", 7.0);
        store_cached_tensor(&cache, key_a.clone(), &embeddings).unwrap();

        // Same positive prompt and guidance, different negative prompt.
        let key_b = cfg_prompt_cache_key("a cat", "low quality", 7.0);
        let restored = restore_cached_tensor(&cache, &key_b, &device, dtype).unwrap();
        assert!(
            restored.is_none(),
            "different negative prompt must miss the cache (silent-wrong-output bug)"
        );

        // The original key still hits.
        let restored = restore_cached_tensor(&cache, &key_a, &device, dtype).unwrap();
        assert!(
            restored.is_some(),
            "identical (pos, neg, guidance) must hit"
        );
    }
}