1use anyhow::{bail, Result};
2use candle_core::{DType, Device, Module, Tensor};
3use candle_transformers::models::stable_diffusion;
4use candle_transformers::models::stable_diffusion::schedulers::PredictionType;
5use mold_core::{GenerateRequest, GenerateResponse, ImageData, ModelPaths, Scheduler};
6use std::collections::HashMap;
7use std::path::{Path, PathBuf};
8use std::sync::{Arc, Mutex};
9use std::time::Instant;
10
11use crate::cache::{
12 cfg_prompt_cache_key, clear_cache, get_or_insert_cached_tensor, image_size_cache_key,
13 latent_size_cache_key, restore_cached_tensor, CachedTensor, CfgPromptCacheKey,
14 ImageSizeCacheKey, LatentSizeCacheKey, LruCache, DEFAULT_IMAGE_CACHE_CAPACITY,
15 DEFAULT_PROMPT_CACHE_CAPACITY,
16};
17use crate::cfg_plus_ddim::DdimAlphaSchedule;
18use crate::controlnet::ControlNetModel;
19use crate::device::{check_memory_budget, memory_status_string, preflight_memory_check};
20use crate::engine::{cfg_active, rand_seed, resolve_cfg_plus, InferenceEngine, LoadStrategy};
21use crate::engine_base::EngineBase;
22use crate::image::{build_output_metadata, encode_image};
23use crate::progress::{ProgressCallback, ProgressEvent};
24
/// Scale factor multiplied into VAE-encoded latents (see `prepare_img2img_latents`).
const VAE_SCALE: f64 = 0.18215;
27
/// Everything the denoise loop needs to apply ControlNet conditioning.
struct ControlNetContext {
    /// Loaded ControlNet whose `forward` produces the down/mid residuals fed to the UNet.
    model: ControlNetModel,
    /// Preprocessed control image handed to `model.forward` on every step.
    control_tensor: Tensor,
    /// Conditioning scale forwarded to `model.forward`.
    scale: f64,
}
34
/// Model components kept resident after an eager `load`.
struct LoadedSD15 {
    /// Denoising UNet; `None` makes `reload_unet_if_needed` rebuild it.
    unet: Option<stable_diffusion::unet_2d::UNet2DConditionModel>,
    /// Autoencoder built with `vae_dtype`.
    vae: stable_diffusion::vae::AutoEncoderKL,
    /// CLIP text encoder, built on `clip_device`.
    clip: stable_diffusion::clip::ClipTextTransformer,
    /// Tokenizer paired with `clip`.
    tokenizer: Arc<tokenizers::Tokenizer>,
    /// Configuration the components were built from.
    sd_config: stable_diffusion::StableDiffusionConfig,
    /// Device the UNet and VAE were built on.
    device: Device,
    /// Device the CLIP encoder was built on.
    clip_device: Device,
    /// UNet dtype (`F16` when `device` is a GPU, otherwise `F32`).
    dtype: DType,
    /// VAE dtype, as returned by `resolve_vae_dtype(dtype)`.
    vae_dtype: DType,
}
52
/// Stable Diffusion 1.5 inference engine.
pub struct SD15Engine {
    /// Shared engine state (model name, paths, load strategy, GPU ordinal, progress, loaded components).
    base: EngineBase<LoadedSD15>,
    /// Scheduler used when a request does not name one.
    scheduler: Scheduler,
    /// Optional pool consulted for tokenizers and CPU-resident VAE tensors.
    shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
    /// Text embeddings keyed by prompt, negative prompt and guidance.
    prompt_cache: Mutex<LruCache<CfgPromptCacheKey, CachedTensor>>,
    /// VAE-encoded source images keyed by image bytes and output size.
    source_latent_cache: Mutex<LruCache<ImageSizeCacheKey, CachedTensor>>,
    /// Decoded inpaint masks keyed by mask bytes and latent size.
    mask_cache: Mutex<LruCache<LatentSizeCacheKey, CachedTensor>>,
    /// Preprocessed ControlNet images keyed by image bytes and output size.
    control_tensor_cache: Mutex<LruCache<ImageSizeCacheKey, CachedTensor>>,
    /// Optional placement; its `text_encoders` entry selects the CLIP device.
    pending_placement: Option<mold_core::types::DevicePlacement>,
    /// Set when the engine was built by `from_single_file`; switches all loaders to the single-file path.
    pub(crate) single_file_path: Option<PathBuf>,
    /// LoRA weights merged into the UNet by `build_unet_for_strategy`.
    pending_loras: Vec<mold_core::LoraWeight>,
    /// NOTE(review): not read in the visible code — presumably the `lora_stack_fingerprint`
    /// of the LoRA stack baked into the currently loaded UNet; confirm against the callers.
    active_lora_fingerprint: Vec<(String, u64)>,
}
85
86fn lora_stack_fingerprint(loras: &[mold_core::LoraWeight]) -> Vec<(String, u64)> {
91 loras
92 .iter()
93 .map(|w| (w.path.clone(), w.scale.to_bits()))
94 .collect()
95}
96
97impl SD15Engine {
98 pub fn new(
99 model_name: String,
100 paths: ModelPaths,
101 scheduler: Scheduler,
102 load_strategy: LoadStrategy,
103 gpu_ordinal: usize,
104 shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
105 ) -> Self {
106 Self {
107 base: EngineBase::new(model_name, paths, load_strategy, gpu_ordinal),
108 scheduler,
109 shared_pool,
110 prompt_cache: Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY)),
111 source_latent_cache: Mutex::new(LruCache::new(DEFAULT_IMAGE_CACHE_CAPACITY)),
112 mask_cache: Mutex::new(LruCache::new(DEFAULT_IMAGE_CACHE_CAPACITY)),
113 control_tensor_cache: Mutex::new(LruCache::new(DEFAULT_IMAGE_CACHE_CAPACITY)),
114 pending_placement: None,
115 single_file_path: None,
116 pending_loras: Vec::new(),
117 active_lora_fingerprint: Vec::new(),
118 }
119 }
120
121 pub fn from_single_file(
138 model_name: String,
139 single_file_path: PathBuf,
140 clip_tokenizer: PathBuf,
141 scheduler: Scheduler,
142 load_strategy: LoadStrategy,
143 gpu_ordinal: usize,
144 shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
145 ) -> Result<Self> {
146 if !single_file_path.exists() {
147 bail!(
148 "single-file checkpoint not found: {}",
149 single_file_path.display()
150 );
151 }
152
153 let bundle = crate::loader::single_file::load(
156 &single_file_path,
157 mold_catalog::families::Family::Sd15,
158 )?;
159
160 let _remap = crate::loader::sd15_keys::build_sd15_remap(&bundle)?;
165
166 let paths = ModelPaths {
172 transformer: single_file_path.clone(),
173 transformer_shards: Vec::new(),
174 vae: single_file_path.clone(),
175 spatial_upscaler: None,
176 temporal_upscaler: None,
177 distilled_lora: None,
178 t5_encoder: None,
179 clip_encoder: Some(single_file_path.clone()),
180 t5_tokenizer: None,
181 clip_tokenizer: Some(clip_tokenizer),
182 clip_encoder_2: None,
183 clip_tokenizer_2: None,
184 text_encoder_files: Vec::new(),
185 text_tokenizer: None,
186 decoder: None,
187 };
188
189 Ok(Self {
190 base: EngineBase::new(model_name, paths, load_strategy, gpu_ordinal),
191 scheduler,
192 shared_pool,
193 prompt_cache: Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY)),
194 source_latent_cache: Mutex::new(LruCache::new(DEFAULT_IMAGE_CACHE_CAPACITY)),
195 mask_cache: Mutex::new(LruCache::new(DEFAULT_IMAGE_CACHE_CAPACITY)),
196 control_tensor_cache: Mutex::new(LruCache::new(DEFAULT_IMAGE_CACHE_CAPACITY)),
197 pending_placement: None,
198 single_file_path: Some(single_file_path),
199 pending_loras: Vec::new(),
200 active_lora_fingerprint: Vec::new(),
201 })
202 }
203
204 fn validate_paths(&self) -> Result<(std::path::PathBuf, std::path::PathBuf)> {
206 let clip_encoder = self
207 .base
208 .paths
209 .clip_encoder
210 .as_ref()
211 .ok_or_else(|| anyhow::anyhow!("CLIP-L encoder path required for SD1.5 models"))?
212 .clone();
213 let clip_tokenizer = self
214 .base
215 .paths
216 .clip_tokenizer
217 .as_ref()
218 .ok_or_else(|| anyhow::anyhow!("CLIP-L tokenizer path required for SD1.5 models"))?
219 .clone();
220
221 for (label, path) in [
222 ("transformer (UNet)", &self.base.paths.transformer),
223 ("vae", &self.base.paths.vae),
224 ("clip_encoder (CLIP-L)", &clip_encoder),
225 ("clip_tokenizer (CLIP-L)", &clip_tokenizer),
226 ] {
227 if !path.exists() {
228 bail!("{label} file not found: {}", path.display());
229 }
230 }
231
232 Ok((clip_encoder, clip_tokenizer))
233 }
234
235 fn load_clip_tokenizer(
236 &self,
237 clip_tokenizer: &std::path::Path,
238 ) -> Result<Arc<tokenizers::Tokenizer>> {
239 if let Some(ref pool) = self.shared_pool {
240 return pool.lock().unwrap().load_tokenizer(clip_tokenizer);
241 }
242 Ok(Arc::new(
243 tokenizers::Tokenizer::from_file(clip_tokenizer)
244 .map_err(|e| anyhow::anyhow!("failed to load CLIP-L tokenizer: {e}"))?,
245 ))
246 }
247
    /// Returns the SD 1.5 configuration with all three optional arguments left unset.
    fn sd_config(&self) -> stable_diffusion::StableDiffusionConfig {
        stable_diffusion::StableDiffusionConfig::v1_5(None, None, None)
    }
252
253 fn reload_unet_if_needed(&mut self) -> Result<()> {
255 let needs_reload = self
256 .base
257 .loaded
258 .as_ref()
259 .map(|l| l.unet.is_none())
260 .unwrap_or(false);
261
262 if needs_reload {
263 let sd_config = self.sd_config();
264 let loaded = self.base.loaded.as_ref().unwrap();
265 let device = loaded.device.clone();
266 let dtype = loaded.dtype;
267 let _ = loaded;
268
269 self.base.progress.stage_start("Reloading UNet (GPU)");
270 let reload_start = Instant::now();
271 let unet = self.build_unet_for_strategy(&sd_config, &device, dtype)?;
272 self.base.loaded.as_mut().unwrap().unet = Some(unet);
273 self.base
274 .progress
275 .stage_done("Reloading UNet (GPU)", reload_start.elapsed());
276 }
277 Ok(())
278 }
279
280 fn build_unet_for_strategy(
293 &self,
294 sd_config: &stable_diffusion::StableDiffusionConfig,
295 device: &Device,
296 dtype: DType,
297 ) -> Result<stable_diffusion::unet_2d::UNet2DConditionModel> {
298 let has_lora = !self.pending_loras.is_empty();
299 if let Some(single_file) = self.single_file_path.as_ref() {
300 let remap = Self::load_sd15_remap(single_file)?;
301 if has_lora {
302 self.build_unet_single_file_with_lora(single_file, &remap, sd_config, device, dtype)
303 } else {
304 Self::build_unet_single_file(single_file, &remap, sd_config, device, dtype)
305 }
306 } else if has_lora {
307 self.build_unet_diffusers_with_lora(sd_config, device, dtype)
308 } else {
309 Ok(sd_config.build_unet(&self.base.paths.transformer, device, 4, false, dtype)?)
310 }
311 }
312
313 fn build_unet_diffusers_with_lora(
318 &self,
319 sd_config: &stable_diffusion::StableDiffusionConfig,
320 device: &Device,
321 dtype: DType,
322 ) -> Result<stable_diffusion::unet_2d::UNet2DConditionModel> {
323 use candle_core::safetensors::MmapedSafetensors;
324 use candle_nn::VarBuilder;
325
326 let st = unsafe { MmapedSafetensors::multi(&[&self.base.paths.transformer])? };
327
328 struct MmapBackend {
329 st: MmapedSafetensors,
330 }
331 impl candle_nn::var_builder::SimpleBackend for MmapBackend {
332 fn get(
333 &self,
334 _s: candle_core::Shape,
335 name: &str,
336 _h: candle_nn::Init,
337 dtype: DType,
338 dev: &Device,
339 ) -> candle_core::Result<Tensor> {
340 let t = self.st.load(name, dev)?;
341 if t.dtype() != dtype {
342 t.to_dtype(dtype)
343 } else {
344 Ok(t)
345 }
346 }
347 fn get_unchecked(
348 &self,
349 name: &str,
350 dtype: DType,
351 dev: &Device,
352 ) -> candle_core::Result<Tensor> {
353 let t = self.st.load(name, dev)?;
354 if t.dtype() != dtype {
355 t.to_dtype(dtype)
356 } else {
357 Ok(t)
358 }
359 }
360 fn contains_tensor(&self, name: &str) -> bool {
361 self.st.get(name).is_ok()
362 }
363 }
364 let inner: Box<dyn candle_nn::var_builder::SimpleBackend> = Box::new(MmapBackend { st });
365 let wrapped = self.wrap_with_loras(inner)?;
366 let vb = VarBuilder::from_backend(wrapped, dtype, device.clone());
367 Ok(stable_diffusion::unet_2d::UNet2DConditionModel::new(
368 vb,
369 4,
370 4,
371 false,
372 sd_config.unet().clone(),
373 )?)
374 }
375
376 fn build_unet_single_file_with_lora(
382 &self,
383 single_file: &std::path::Path,
384 remap: &crate::loader::Sd15Remap,
385 sd_config: &stable_diffusion::StableDiffusionConfig,
386 device: &Device,
387 dtype: DType,
388 ) -> Result<stable_diffusion::unet_2d::UNet2DConditionModel> {
389 use crate::loader::SingleFileBackend;
390 use candle_nn::VarBuilder;
391
392 let backend = SingleFileBackend::from_sd15_unet(single_file, remap)?;
393 let inner: Box<dyn candle_nn::var_builder::SimpleBackend> = Box::new(backend);
394 let wrapped = self.wrap_with_loras(inner)?;
395 let vb = VarBuilder::from_backend(wrapped, dtype, device.clone());
396 Ok(stable_diffusion::unet_2d::UNet2DConditionModel::new(
397 vb,
398 4,
399 4,
400 false,
401 sd_config.unet().clone(),
402 )?)
403 }
404
405 fn wrap_with_loras(
409 &self,
410 inner: Box<dyn candle_nn::var_builder::SimpleBackend>,
411 ) -> Result<Box<dyn candle_nn::var_builder::SimpleBackend>> {
412 let adapters = super::lora::load_lora_adapters(&self.pending_loras, &self.base.progress)?;
413 let specs: Vec<super::lora::Sd15LoraSpec<'_>> = adapters
414 .iter()
415 .zip(self.pending_loras.iter())
416 .map(|(adapter, w)| super::lora::Sd15LoraSpec {
417 adapter: adapter.as_ref(),
418 scale: w.scale,
419 path_hash: super::lora::lora_path_hash(&w.path),
420 })
421 .collect();
422 super::lora::wrap_backend_with_lora(inner, &specs, &self.base.progress, None)
423 }
424
425 fn build_vae_for_strategy(
430 &self,
431 sd_config: &stable_diffusion::StableDiffusionConfig,
432 device: &Device,
433 dtype: DType,
434 ) -> Result<stable_diffusion::vae::AutoEncoderKL> {
435 if let Some(single_file) = self.single_file_path.as_ref() {
436 let remap = Self::load_sd15_remap(single_file)?;
437 Self::build_vae_single_file(single_file, &remap, sd_config, device, dtype)
438 } else {
439 self.build_vae_diffusers(sd_config, device, dtype)
440 }
441 }
442
    /// Test-only shortcut: loads the CPU tensors for the engine's configured VAE path.
    #[cfg(test)]
    fn load_vae_cpu_tensors(&self) -> Result<Option<Arc<HashMap<String, Tensor>>>> {
        self.load_vae_cpu_tensors_for_path(&self.base.paths.vae)
    }
447
448 fn load_vae_cpu_tensors_for_path(
449 &self,
450 vae_path: &Path,
451 ) -> Result<Option<Arc<HashMap<String, Tensor>>>> {
452 let Some(shared_pool) = &self.shared_pool else {
453 return Ok(None);
454 };
455 shared_pool
456 .lock()
457 .unwrap()
458 .load_safetensors_cpu_tensors(std::slice::from_ref(&vae_path))
459 }
460
461 fn load_vae_var_builder<'a>(
462 &self,
463 vae_path: &Path,
464 dtype: DType,
465 device: &Device,
466 component: &str,
467 ) -> Result<candle_nn::VarBuilder<'a>> {
468 if let Some(tensors) = self.load_vae_cpu_tensors_for_path(vae_path)? {
469 return Ok(crate::encoders::park::varbuilder_from_parked(
470 tensors.as_ref(),
471 dtype,
472 device,
473 ));
474 }
475
476 crate::weight_loader::load_safetensors_with_progress(
477 std::slice::from_ref(&vae_path),
478 dtype,
479 device,
480 component,
481 &self.base.progress,
482 )
483 }
484
485 fn build_vae_diffusers(
486 &self,
487 sd_config: &stable_diffusion::StableDiffusionConfig,
488 device: &Device,
489 dtype: DType,
490 ) -> Result<stable_diffusion::vae::AutoEncoderKL> {
491 let vb = self.load_vae_var_builder(&self.base.paths.vae, dtype, device, "VAE")?;
492 Ok(stable_diffusion::vae::AutoEncoderKL::new(
493 vb,
494 3,
495 3,
496 sd_config.autoencoder().clone(),
497 )?)
498 }
499
500 fn load_sd15_remap(single_file: &std::path::Path) -> Result<crate::loader::Sd15Remap> {
504 use crate::loader::{build_sd15_remap, single_file as single_file_loader};
505 use mold_catalog::families::Family;
506 let bundle = single_file_loader::load(single_file, Family::Sd15)
507 .map_err(|e| anyhow::anyhow!("partition single-file SD1.5 checkpoint: {e}"))?;
508 build_sd15_remap(&bundle)
509 .map_err(|e| anyhow::anyhow!("build SD1.5 diffusers→A1111 remap: {e}"))
510 }
511
512 fn build_unet_single_file(
518 single_file: &std::path::Path,
519 remap: &crate::loader::Sd15Remap,
520 sd_config: &stable_diffusion::StableDiffusionConfig,
521 device: &Device,
522 dtype: DType,
523 ) -> Result<stable_diffusion::unet_2d::UNet2DConditionModel> {
524 use crate::loader::SingleFileBackend;
525 use candle_nn::VarBuilder;
526 let backend = SingleFileBackend::from_sd15_unet(single_file, remap)?;
527 let vb = VarBuilder::from_backend(Box::new(backend), dtype, device.clone());
528 Ok(stable_diffusion::unet_2d::UNet2DConditionModel::new(
529 vb,
530 4,
531 4,
532 false,
533 sd_config.unet().clone(),
534 )?)
535 }
536
537 fn build_vae_single_file(
540 single_file: &std::path::Path,
541 remap: &crate::loader::Sd15Remap,
542 sd_config: &stable_diffusion::StableDiffusionConfig,
543 device: &Device,
544 dtype: DType,
545 ) -> Result<stable_diffusion::vae::AutoEncoderKL> {
546 use crate::loader::SingleFileBackend;
547 use candle_nn::VarBuilder;
548 let backend = SingleFileBackend::from_sd15_vae(single_file, remap)?;
549 let vb = VarBuilder::from_backend(Box::new(backend), dtype, device.clone());
550 Ok(stable_diffusion::vae::AutoEncoderKL::new(
551 vb,
552 3,
553 3,
554 sd_config.autoencoder().clone(),
555 )?)
556 }
557
558 fn build_clip_single_file(
561 single_file: &std::path::Path,
562 remap: &crate::loader::Sd15Remap,
563 clip_config: &stable_diffusion::clip::Config,
564 clip_device: &Device,
565 ) -> Result<stable_diffusion::clip::ClipTextTransformer> {
566 use crate::loader::SingleFileBackend;
567 use candle_nn::VarBuilder;
568 let backend = SingleFileBackend::from_sd15_clip_l(single_file, remap)?;
569 let vb = VarBuilder::from_backend(Box::new(backend), DType::F32, clip_device.clone());
570 Ok(stable_diffusion::clip::ClipTextTransformer::new(
571 vb,
572 clip_config,
573 )?)
574 }
575
576 pub fn load(&mut self) -> Result<()> {
596 if self.base.loaded.is_some() {
597 return Ok(());
598 }
599
600 if self.base.load_strategy == LoadStrategy::Sequential {
602 return Ok(());
603 }
604
605 let (clip_encoder, clip_tokenizer) = self.validate_paths()?;
606
607 tracing::info!(model = %self.base.model_name, "loading SD1.5 model components...");
608
609 let device = crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)?;
610 let dtype = if crate::device::is_gpu(&device) {
611 DType::F16
612 } else {
613 DType::F32
614 };
615
616 let sd_config = self.sd_config();
617
618 let tier1 = self
620 .pending_placement
621 .as_ref()
622 .map(|p| p.text_encoders)
623 .unwrap_or_default();
624 let clip_device = crate::device::resolve_device(Some(tier1), || Ok(device.clone()))?;
625
626 let vae_dtype = crate::device::resolve_vae_dtype(dtype);
628 let (unet, vae, clip) = if let Some(single_file) = self.single_file_path.clone() {
629 self.load_components_single_file(
630 &single_file,
631 &sd_config,
632 &device,
633 &clip_device,
634 dtype,
635 vae_dtype,
636 )?
637 } else {
638 self.load_components_diffusers(
639 &clip_encoder,
640 &sd_config,
641 &device,
642 &clip_device,
643 dtype,
644 vae_dtype,
645 )?
646 };
647
648 let tokenizer = self.load_clip_tokenizer(&clip_tokenizer)?;
649
650 self.base.loaded = Some(LoadedSD15 {
651 unet: Some(unet),
652 vae,
653 clip,
654 tokenizer,
655 sd_config,
656 device,
657 clip_device,
658 vae_dtype,
659 dtype,
660 });
661
662 tracing::info!(model = %self.base.model_name, "all SD1.5 components loaded successfully");
663 Ok(())
664 }
665
666 #[allow(clippy::too_many_arguments)]
668 fn load_components_diffusers(
669 &mut self,
670 clip_encoder: &std::path::Path,
671 sd_config: &stable_diffusion::StableDiffusionConfig,
672 device: &Device,
673 clip_device: &Device,
674 dtype: DType,
675 vae_dtype: DType,
676 ) -> Result<(
677 stable_diffusion::unet_2d::UNet2DConditionModel,
678 stable_diffusion::vae::AutoEncoderKL,
679 stable_diffusion::clip::ClipTextTransformer,
680 )> {
681 self.base.progress.stage_start("Loading UNet (GPU)");
682 let unet_start = Instant::now();
683 let unet = sd_config.build_unet(
684 &self.base.paths.transformer,
685 device,
686 4, false, dtype,
689 )?;
690 self.base
691 .progress
692 .stage_done("Loading UNet (GPU)", unet_start.elapsed());
693
694 self.base.progress.stage_start("Loading VAE (GPU)");
695 let vae_start = Instant::now();
696 let vae = self.build_vae_diffusers(sd_config, device, vae_dtype)?;
697 self.base
698 .progress
699 .stage_done("Loading VAE (GPU)", vae_start.elapsed());
700
701 self.base.progress.stage_start("Loading CLIP-L encoder");
702 let clip_start = Instant::now();
703 let clip = stable_diffusion::build_clip_transformer(
704 &sd_config.clip,
705 clip_encoder,
706 clip_device,
707 DType::F32,
708 )?;
709 self.base
710 .progress
711 .stage_done("Loading CLIP-L encoder", clip_start.elapsed());
712
713 Ok((unet, vae, clip))
714 }
715
716 #[allow(clippy::too_many_arguments)]
728 fn load_components_single_file(
729 &mut self,
730 single_file: &std::path::Path,
731 sd_config: &stable_diffusion::StableDiffusionConfig,
732 device: &Device,
733 clip_device: &Device,
734 dtype: DType,
735 vae_dtype: DType,
736 ) -> Result<(
737 stable_diffusion::unet_2d::UNet2DConditionModel,
738 stable_diffusion::vae::AutoEncoderKL,
739 stable_diffusion::clip::ClipTextTransformer,
740 )> {
741 let remap = Self::load_sd15_remap(single_file)?;
742
743 self.base.progress.stage_start("Loading UNet (single-file)");
744 let unet_start = Instant::now();
745 let unet = Self::build_unet_single_file(single_file, &remap, sd_config, device, dtype)?;
746 self.base
747 .progress
748 .stage_done("Loading UNet (single-file)", unet_start.elapsed());
749
750 self.base.progress.stage_start("Loading VAE (single-file)");
751 let vae_start = Instant::now();
752 let vae = Self::build_vae_single_file(single_file, &remap, sd_config, device, vae_dtype)?;
753 self.base
754 .progress
755 .stage_done("Loading VAE (single-file)", vae_start.elapsed());
756
757 self.base
758 .progress
759 .stage_start("Loading CLIP-L (single-file)");
760 let clip_start = Instant::now();
761 let clip = Self::build_clip_single_file(single_file, &remap, &sd_config.clip, clip_device)?;
762 self.base
763 .progress
764 .stage_done("Loading CLIP-L (single-file)", clip_start.elapsed());
765
766 Ok((unet, vae, clip))
767 }
768
769 fn tokenize(
771 tokenizer: &tokenizers::Tokenizer,
772 prompt: &str,
773 max_len: usize,
774 device: &Device,
775 ) -> Result<Tensor> {
776 let encoding = tokenizer
777 .encode(prompt, true)
778 .map_err(|e| anyhow::anyhow!("tokenization failed: {e}"))?;
779 let mut ids = encoding.get_ids().to_vec();
780 ids.truncate(max_len);
781 while ids.len() < max_len {
783 ids.push(0);
784 }
785 let ids = ids.into_iter().map(|i| i as i64).collect::<Vec<_>>();
786 Ok(Tensor::new(ids, device)?.unsqueeze(0)?)
787 }
788
789 #[allow(clippy::too_many_arguments)]
793 fn denoise_loop(
794 &self,
795 unet: &stable_diffusion::unet_2d::UNet2DConditionModel,
796 text_embeddings: &Tensor,
797 sched: Scheduler,
798 latents: &mut Tensor,
799 guidance: f64,
800 cfg_plus: bool,
801 steps: u32,
802 start_step: usize,
803 inpaint_ctx: Option<&crate::img_utils::InpaintContext>,
804 controlnet_ctx: Option<&ControlNetContext>,
805 ) -> Result<()> {
806 let use_cfg = cfg_active(guidance);
807 let mut scheduler = crate::scheduler::build_scheduler(
808 sched,
809 steps as usize,
810 PredictionType::Epsilon,
811 false,
812 )?;
813 let timesteps = scheduler.timesteps().to_vec();
814 let active_timesteps = ×teps[start_step..];
815
816 let cfg_plus_schedule = if cfg_plus && use_cfg && matches!(sched, Scheduler::Ddim) {
819 Some(DdimAlphaSchedule::from_default(steps as usize))
820 } else {
821 if cfg_plus && !use_cfg {
822 tracing::warn!(
823 guidance,
824 "cfg_plus requested but cfg_scale ≈ 1.0 — falling back to standard step (no uncond available)"
825 );
826 } else if cfg_plus {
827 tracing::warn!(
828 scheduler = ?sched,
829 "cfg_plus requested but only DDIM is supported on SDXL/SD1.5 — falling back to standard step. Re-run with `--scheduler ddim` to enable CFG++."
830 );
831 }
832 None
833 };
834
835 let denoise_label = format!("Denoising ({} steps)", active_timesteps.len());
836 self.base.progress.stage_start(&denoise_label);
837 let denoise_start = Instant::now();
838
839 for (step_idx, &t) in active_timesteps.iter().enumerate() {
840 let step_start = std::time::Instant::now();
841 let latent_input = if use_cfg {
842 Tensor::cat(&[&*latents, &*latents], 0)?
843 } else {
844 latents.clone()
845 };
846
847 let latent_input = scheduler.scale_model_input(latent_input, t)?;
848
849 let noise_pred = if let Some(cn_ctx) = controlnet_ctx {
850 let (down_residuals, mid_residual) = cn_ctx.model.forward(
851 &latent_input,
852 t as f64,
853 text_embeddings,
854 &cn_ctx.control_tensor,
855 cn_ctx.scale,
856 )?;
857 unet.forward_with_additional_residuals(
858 &latent_input,
859 t as f64,
860 text_embeddings,
861 Some(&down_residuals),
862 Some(&mid_residual),
863 )?
864 } else {
865 unet.forward(&latent_input, t as f64, text_embeddings)?
866 };
867
868 let (noise_pred_blended, noise_pred_uncond_opt) = if use_cfg {
871 let chunks = noise_pred.chunk(2, 0)?;
872 let noise_pred_uncond = chunks[0].clone();
873 let noise_pred_cond = &chunks[1];
874 let blended =
875 (&noise_pred_uncond + ((noise_pred_cond - &noise_pred_uncond)? * guidance)?)?;
876 (blended, Some(noise_pred_uncond))
877 } else {
878 (noise_pred, None)
879 };
880
881 *latents = match (cfg_plus_schedule.as_ref(), noise_pred_uncond_opt.as_ref()) {
882 (Some(ddim_sched), Some(eps_uncond)) => {
883 ddim_sched.cfg_plus_step(&*latents, &noise_pred_blended, eps_uncond, t)?
884 }
885 _ => scheduler.step(&noise_pred_blended, t, &*latents)?,
886 };
887
888 if let Some(ctx) = inpaint_ctx {
890 let noised_original =
891 scheduler.add_noise(&ctx.original_latents, ctx.noise.clone(), t)?;
892 *latents = crate::img2img::blend_inpaint_latents(&*latents, ctx, &noised_original)?;
893 }
894
895 self.base.progress.emit(ProgressEvent::DenoiseStep {
896 step: step_idx + 1,
897 total: active_timesteps.len(),
898 elapsed: step_start.elapsed(),
899 });
900 }
901
902 self.base
903 .progress
904 .stage_done(&denoise_label, denoise_start.elapsed());
905 Ok(())
906 }
907
908 #[allow(clippy::too_many_arguments)]
915 fn prepare_img2img_latents(
916 &self,
917 vae: &stable_diffusion::vae::AutoEncoderKL,
918 source_bytes: &[u8],
919 width: u32,
920 height: u32,
921 strength: f64,
922 steps: u32,
923 sched: Scheduler,
924 seed: u64,
925 device: &Device,
926 dtype: DType,
927 vae_dtype: DType,
928 ) -> Result<(Tensor, usize, Tensor, Tensor)> {
929 use crate::img_utils::{decode_source_image, NormalizeRange};
930 let cache_key = image_size_cache_key(source_bytes, width, height);
931 let (encoded, cache_hit) = get_or_insert_cached_tensor(
932 &self.source_latent_cache,
933 cache_key,
934 device,
935 dtype,
936 || {
937 self.base
938 .progress
939 .stage_start("Encoding source image (VAE)");
940 let encode_start = Instant::now();
941
942 let source_tensor = decode_source_image(
943 source_bytes,
944 width,
945 height,
946 NormalizeRange::MinusOneToOne,
947 device,
948 vae_dtype,
949 )?;
950
951 let encoded = vae.encode(&source_tensor)?;
952 let encoded = (encoded.mode()? * VAE_SCALE)?;
953 let encoded = encoded.to_dtype(dtype)?;
956
957 self.base
958 .progress
959 .stage_done("Encoding source image (VAE)", encode_start.elapsed());
960 Ok(encoded)
961 },
962 )?;
963 if cache_hit {
964 self.base.progress.cache_hit("source image latents");
965 }
966
967 let start_step = crate::img2img::img2img_start_index(steps as usize, strength);
968
969 let scheduler = crate::scheduler::build_scheduler(
971 sched,
972 steps as usize,
973 PredictionType::Epsilon,
974 false,
975 )?;
976 let timesteps = scheduler.timesteps().to_vec();
977
978 let latent_h = height as usize / 8;
979 let latent_w = width as usize / 8;
980 let noise =
981 crate::engine::seeded_randn(seed, &[1, 4, latent_h, latent_w], device, DType::F32)?;
982 let noise = noise.to_dtype(dtype)?;
983
984 let noised = if start_step < timesteps.len() {
986 scheduler.add_noise(&encoded, noise.clone(), timesteps[start_step])?
987 } else {
988 encoded.clone()
989 };
990
991 tracing::info!(
992 start_step,
993 total_steps = steps,
994 strength,
995 "img2img: starting from step {start_step}"
996 );
997
998 Ok((noised, start_step, encoded, noise))
999 }
1000
1001 #[allow(clippy::too_many_arguments)]
1003 #[allow(clippy::too_many_arguments)]
1004 fn encode_prompt(
1005 &self,
1006 clip: &stable_diffusion::clip::ClipTextTransformer,
1007 tokenizer: &tokenizers::Tokenizer,
1008 prompt: &str,
1009 negative_prompt: &str,
1010 max_len: usize,
1011 device: &Device,
1012 clip_device: &Device,
1013 dtype: DType,
1014 guidance: f64,
1015 ) -> Result<Tensor> {
1016 let cache_key = cfg_prompt_cache_key(prompt, negative_prompt, guidance);
1022 let (text_embeddings, cache_hit) =
1023 get_or_insert_cached_tensor(&self.prompt_cache, cache_key, device, dtype, || {
1024 let use_cfg = cfg_active(guidance);
1025
1026 self.base.progress.stage_start("Encoding prompt (CLIP-L)");
1027 let encode_start = Instant::now();
1028 let tokens = Self::tokenize(tokenizer, prompt, max_len, clip_device)?;
1029 let text_embeddings = clip.forward(&tokens)?;
1030 self.base
1031 .progress
1032 .stage_done("Encoding prompt (CLIP-L)", encode_start.elapsed());
1033
1034 let text_embeddings = if use_cfg {
1035 let uncond_tokens =
1036 Self::tokenize(tokenizer, negative_prompt, max_len, clip_device)?;
1037 let uncond_embeddings = clip.forward(&uncond_tokens)?;
1038 Tensor::cat(&[&uncond_embeddings, &text_embeddings], 0)?
1039 } else {
1040 text_embeddings
1041 };
1042
1043 let text_embeddings = text_embeddings.to_device(device)?;
1044 Ok(text_embeddings.to_dtype(dtype)?)
1045 })?;
1046 if cache_hit {
1047 self.base.progress.cache_hit("prompt conditioning");
1048 return Ok(text_embeddings);
1049 }
1050 Ok(text_embeddings)
1051 }
1052
1053 fn cached_mask(
1054 &self,
1055 mask_bytes: &[u8],
1056 latent_h: usize,
1057 latent_w: usize,
1058 device: &Device,
1059 dtype: DType,
1060 ) -> Result<Tensor> {
1061 let key = latent_size_cache_key(mask_bytes, latent_h, latent_w);
1062 let (mask, cache_hit) =
1063 get_or_insert_cached_tensor(&self.mask_cache, key, device, dtype, || {
1064 crate::img_utils::decode_mask_image(mask_bytes, latent_h, latent_w, device, dtype)
1065 })?;
1066 if cache_hit {
1067 self.base.progress.cache_hit("inpaint mask");
1068 return Ok(mask);
1069 }
1070 Ok(mask)
1071 }
1072
1073 fn load_controlnet(
1075 &self,
1076 req: &GenerateRequest,
1077 device: &Device,
1078 dtype: DType,
1079 ) -> Result<Option<ControlNetContext>> {
1080 let (control_bytes, control_model_name, scale) =
1081 match (req.control_image.as_ref(), req.control_model.as_ref()) {
1082 (Some(bytes), Some(name)) => (bytes, name, req.control_scale),
1083 _ => return Ok(None),
1084 };
1085
1086 self.base.progress.stage_start("Loading ControlNet");
1087 let cn_start = Instant::now();
1088
1089 let config = mold_core::Config::load_or_default();
1091 let cn_paths =
1092 mold_core::ModelPaths::resolve(control_model_name, &config).ok_or_else(|| {
1093 anyhow::anyhow!(
1094 "ControlNet model '{}' not found. Pull it with: mold pull {}",
1095 control_model_name,
1096 control_model_name,
1097 )
1098 })?;
1099
1100 let unet_config =
1102 candle_transformers::models::stable_diffusion::unet_2d::UNet2DConditionModelConfig::default();
1103 let model = ControlNetModel::load(
1104 &cn_paths.transformer,
1105 device,
1106 dtype,
1107 unet_config,
1108 &self.base.progress,
1109 )?;
1110
1111 self.base
1112 .progress
1113 .stage_done("Loading ControlNet", cn_start.elapsed());
1114
1115 self.base
1117 .progress
1118 .stage_start("Preprocessing control image");
1119 let preprocess_start = Instant::now();
1120
1121 let control_key = image_size_cache_key(control_bytes, req.width, req.height);
1122 let (control_tensor, cache_hit) = get_or_insert_cached_tensor(
1123 &self.control_tensor_cache,
1124 control_key,
1125 device,
1126 dtype,
1127 || {
1128 crate::img_utils::decode_source_image(
1129 control_bytes,
1130 req.width,
1131 req.height,
1132 crate::img_utils::NormalizeRange::ZeroToOne,
1133 device,
1134 dtype,
1135 )
1136 },
1137 )?;
1138 if cache_hit {
1139 self.base.progress.cache_hit("control preprocessing");
1140 }
1141
1142 self.base
1143 .progress
1144 .stage_done("Preprocessing control image", preprocess_start.elapsed());
1145
1146 tracing::info!(
1147 control_model = %control_model_name,
1148 scale,
1149 "ControlNet loaded and control image preprocessed"
1150 );
1151
1152 Ok(Some(ControlNetContext {
1153 model,
1154 control_tensor,
1155 scale,
1156 }))
1157 }
1158
/// Runs one SD1.5 generation in sequential ("load-use-drop") mode: each model
/// component is built only when its stage starts and dropped before the next
/// heavy component is needed.
///
/// Stages, in the order they appear in the body:
/// 1. Prompt conditioning — restored from `prompt_cache` when possible,
///    otherwise the tokenizer and CLIP-L are loaded, used, and dropped.
/// 2. UNet load, preceded by a preflight memory check.
/// 3. Initial latents — derived from `req.source_image` through a temporary
///    VAE (img2img / inpaint), or seeded noise scaled by the scheduler's
///    initial sigma (txt2img).
/// 4. Denoising, after which the ControlNet context, inpaint context, UNet
///    and text embeddings are dropped.
/// 5. VAE (re)load, decode, conversion to `u8`, and image encoding.
///
/// # Errors
///
/// Propagates any error from path validation, device creation, model
/// construction, the preflight memory check, denoising, VAE decode, or image
/// encoding.
fn generate_sequential(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
    // Resolve the CLIP encoder / tokenizer paths before doing any heavy work.
    let (clip_encoder, clip_tokenizer) = self.validate_paths()?;

    // A budget warning is advisory only: it is reported and generation continues.
    if let Some(warning) = check_memory_budget(&self.base.paths, LoadStrategy::Sequential) {
        self.base.progress.info(&warning);
    }

    let device = crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)?;
    // Half precision on GPU, full precision otherwise.
    let dtype = if crate::device::is_gpu(&device) {
        DType::F16
    } else {
        DType::F32
    };

    let sd_config = self.sd_config();
    let max_len = sd_config.clip.max_position_embeddings;

    // The reported generation time starts here, i.e. after device creation.
    let start = Instant::now();
    let seed = req.seed.unwrap_or_else(rand_seed);

    let width = req.width as usize;
    let height = req.height as usize;
    let guidance = req.guidance;

    tracing::info!(
        prompt = %req.prompt,
        seed, width, height,
        steps = req.steps,
        guidance,
        "starting sequential SD1.5 generation"
    );

    self.base
        .progress
        .info("Using sequential loading (load-use-drop) to minimize peak memory");

    // --- Stage 1: prompt conditioning -------------------------------------
    let neg = req.negative_prompt.as_deref().unwrap_or("");
    // The key covers prompt, negative prompt and guidance.
    let cache_key = cfg_prompt_cache_key(&req.prompt, neg, guidance);
    let text_embeddings = if let Some(tensor) =
        restore_cached_tensor(&self.prompt_cache, &cache_key, &device, dtype)?
    {
        // Cache hit: the tokenizer and CLIP-L are never loaded for this request.
        self.base.progress.cache_hit("prompt conditioning");
        tensor
    } else {
        if let Some(status) = memory_status_string() {
            self.base.progress.info(&status);
        }

        let tokenizer = self.load_clip_tokenizer(&clip_tokenizer)?;

        // Per-request placement for the text encoders, or the type's default
        // when the request carried no placement.
        let tier1 = self
            .pending_placement
            .as_ref()
            .map(|p| p.text_encoders)
            .unwrap_or_default();
        // The closure supplies the main device to `resolve_device`
        // (presumably as the fallback when placement does not pick one —
        // confirm against `resolve_device`).
        let clip_device = crate::device::resolve_device(Some(tier1), || Ok(device.clone()))?;
        let clip = if let Some(single_file) = self.single_file_path.clone() {
            // Single-file checkpoint: build CLIP-L through the remap-based loader.
            let remap = Self::load_sd15_remap(&single_file)?;
            self.base
                .progress
                .stage_start("Loading CLIP-L (single-file)");
            let clip_start = Instant::now();
            let clip = Self::build_clip_single_file(
                &single_file,
                &remap,
                &sd_config.clip,
                &clip_device,
            )?;
            self.base
                .progress
                .stage_done("Loading CLIP-L (single-file)", clip_start.elapsed());
            clip
        } else {
            // Separate encoder file: candle's loader, always in F32.
            self.base.progress.stage_start("Loading CLIP-L encoder");
            let clip_start = Instant::now();
            let clip = stable_diffusion::build_clip_transformer(
                &sd_config.clip,
                &clip_encoder,
                &clip_device,
                DType::F32,
            )?;
            self.base
                .progress
                .stage_done("Loading CLIP-L encoder", clip_start.elapsed());
            clip
        };

        let text_embeddings = self.encode_prompt(
            &clip,
            &tokenizer,
            &req.prompt,
            neg,
            max_len,
            &device,
            &clip_device,
            dtype,
            guidance,
        )?;

        // Release the encoder before the UNet is loaded.
        drop(clip);
        self.base.progress.info("Freed CLIP-L encoder");
        tracing::info!("CLIP encoder dropped (sequential mode)");

        text_embeddings
    };

    // --- Stage 2: UNet ----------------------------------------------------
    // On-disk size of the UNet weights; treated as 0 when metadata is unavailable.
    let unet_size = std::fs::metadata(&self.base.paths.transformer)
        .map(|m| m.len())
        .unwrap_or(0);
    // NOTE(review): uses `guidance > 1.0` directly rather than `cfg_active`;
    // presumably batch 2 reflects cond + uncond under CFG — confirm both agree.
    let unet_batch = if req.guidance > 1.0 { 2 } else { 1 };
    // NOTE(review): the SD1.5 UNet is budgeted with the `SdxlUnet` activation
    // family — confirm this is intentional.
    let unet_activation_budget = crate::device::activation_bytes(
        req.width,
        req.height,
        unet_batch,
        crate::device::dtype_bytes(dtype),
        crate::device::ActivationFamily::SdxlUnet,
    );
    preflight_memory_check("UNet", unet_size, unet_activation_budget)?;
    if let Some(status) = memory_status_string() {
        self.base.progress.info(&status);
    }

    self.base.progress.stage_start("Loading UNet (GPU)");
    let unet_start = Instant::now();
    let unet = self.build_unet_for_strategy(&sd_config, &device, dtype)?;
    self.base
        .progress
        .stage_done("Loading UNet (GPU)", unet_start.elapsed());

    // A scheduler named on the request overrides the engine default.
    let sched = req.scheduler.unwrap_or(self.scheduler);
    let is_img2img = req.source_image.is_some();

    // --- Stage 3: initial latents -----------------------------------------
    let (mut latents, start_step, inpaint_ctx) = if let Some(ref source_bytes) =
        req.source_image
    {
        self.base
            .progress
            .info("img2img mode: encoding source image before denoising");

        // NOTE(review): the UNet built above is still alive here, so this
        // temporary VAE coexists with it while the source image is encoded.
        self.base.progress.stage_start("Loading VAE (GPU)");
        let vae_start = Instant::now();
        let vae_dtype = crate::device::resolve_vae_dtype(dtype);
        let vae = self.build_vae_for_strategy(&sd_config, &device, vae_dtype)?;
        self.base
            .progress
            .stage_done("Loading VAE (GPU)", vae_start.elapsed());

        let (latents, start_step, encoded, noise) = self.prepare_img2img_latents(
            &vae,
            source_bytes,
            req.width,
            req.height,
            req.strength,
            req.steps,
            sched,
            seed,
            &device,
            dtype,
            vae_dtype,
        )?;

        // A mask turns img2img into inpainting; the mask is requested at
        // latent resolution (pixel size / 8).
        let inpaint_ctx = if let Some(ref mask_bytes) = req.mask_image {
            let mask = self.cached_mask(mask_bytes, height / 8, width / 8, &device, dtype)?;
            Some(crate::img_utils::InpaintContext {
                original_latents: encoded,
                mask,
                noise,
            })
        } else {
            None
        };

        // This VAE instance is not kept; a fresh one is built for decoding.
        drop(vae);
        self.base
            .progress
            .info("Freed VAE (will reload for decode)");
        device.synchronize()?;

        (latents, start_step, inpaint_ctx)
    } else {
        // txt2img: start from seeded noise at latent resolution.
        let latent_h = height / 8;
        let latent_w = width / 8;
        // Throwaway scheduler, built only to read its initial noise sigma.
        let init_scheduler = crate::scheduler::build_scheduler(
            sched,
            req.steps as usize,
            PredictionType::Epsilon,
            false,
        )?;
        let init_noise_sigma = init_scheduler.init_noise_sigma();
        drop(init_scheduler);
        // Noise is generated in F32 and cast to the working dtype afterwards.
        let latents = (crate::engine::seeded_randn(
            seed,
            &[1, 4, latent_h, latent_w],
            &device,
            DType::F32,
        )? * init_noise_sigma)?;
        (latents.to_dtype(dtype)?, 0, None)
    };

    let controlnet_ctx = self.load_controlnet(req, &device, dtype)?;

    // --- Stage 4: denoising -----------------------------------------------
    self.denoise_loop(
        &unet,
        &text_embeddings,
        sched,
        &mut latents,
        guidance,
        resolve_cfg_plus(req),
        req.steps,
        start_step,
        inpaint_ctx.as_ref(),
        controlnet_ctx.as_ref(),
    )?;

    // Everything used only by denoising is released before the VAE is built.
    drop(controlnet_ctx);
    drop(inpaint_ctx);
    drop(unet);
    drop(text_embeddings);
    device.synchronize()?;
    self.base.progress.info("Freed UNet");
    tracing::info!("UNet dropped (sequential mode)");

    // --- Stage 5: VAE decode and encoding ----------------------------------
    // In img2img mode the VAE was already loaded once for source encoding.
    let vae_load_label = if is_img2img {
        "Reloading VAE (GPU)"
    } else {
        "Loading VAE (GPU)"
    };
    self.base.progress.stage_start(vae_load_label);
    let vae_start = Instant::now();
    let vae_dtype = crate::device::resolve_vae_dtype(dtype);
    let vae = self.build_vae_for_strategy(&sd_config, &device, vae_dtype)?;
    self.base
        .progress
        .stage_done(vae_load_label, vae_start.elapsed());

    self.base.progress.stage_start("VAE decode");
    let vae_decode_start = Instant::now();

    // Divide out the latent scale factor before handing latents to the VAE.
    let latents = (latents / VAE_SCALE)?;
    let latents_for_vae = latents.to_dtype(vae_dtype)?;
    let device_for_sync = device.clone();
    // First closure: the decode itself. Second closure: synchronizes the
    // device; a failed synchronize is only logged, never propagated.
    // (Presumably the helper runs it between an OOM and its retry — confirm
    // against `decode_with_oom_fallback`.)
    let img = crate::vae_tiling::decode_with_oom_fallback(
        &latents_for_vae,
        |t| vae.decode(t).map_err(Into::into),
        || {
            if let Err(e) = device_for_sync.synchronize() {
                tracing::warn!(
                    "SD1.5 (sequential) device.synchronize() after VAE OOM failed: {e}"
                );
            }
        },
    )?;

    // Rescale as x / 2 + 0.5, clamp to [0, 1], scale to 0-255 bytes, then
    // drop the leading dimension.
    let img = ((img / 2.)? + 0.5)?.clamp(0f32, 1f32)?;
    let img = (img * 255.)?.to_dtype(DType::U8)?;
    let img = img.squeeze(0)?;

    self.base
        .progress
        .stage_done("VAE decode", vae_decode_start.elapsed());

    let output_metadata = build_output_metadata(req, seed, Some(sched));
    let image_bytes = encode_image(
        &img,
        req.resolved_output_format(),
        req.width,
        req.height,
        output_metadata.as_ref(),
    )?;

    let generation_time_ms = start.elapsed().as_millis() as u64;
    tracing::info!(
        generation_time_ms,
        seed,
        "sequential SD1.5 generation complete"
    );

    // Exactly one image is returned per request.
    Ok(GenerateResponse {
        images: vec![ImageData {
            data: image_bytes,
            format: req.resolved_output_format(),
            width: req.width,
            height: req.height,
            index: 0,
        }],
        generation_time_ms,
        model: req.model.clone(),
        seed_used: seed,
        video: None,
        gpu: None,
    })
}
1473}
1474
1475impl SD15Engine {
1476 fn generate_inner(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1477 if self.base.load_strategy == LoadStrategy::Sequential {
1479 return self.generate_sequential(req);
1480 }
1481
1482 let requested_stack = lora_stack_fingerprint(&self.pending_loras);
1486 if requested_stack != self.active_lora_fingerprint {
1487 if let Some(loaded) = self.base.loaded.as_mut() {
1488 if loaded.unet.is_some() {
1489 loaded.unet = None;
1490 loaded.device.synchronize()?;
1491 tracing::info!("SD1.5 UNet dropped (LoRA stack changed)");
1492 }
1493 }
1494 self.active_lora_fingerprint = requested_stack;
1495 }
1496
1497 self.reload_unet_if_needed()?;
1499
1500 let loaded = self
1501 .base
1502 .loaded
1503 .as_ref()
1504 .ok_or_else(|| anyhow::anyhow!("model not loaded — call load() first"))?;
1505
1506 let start = Instant::now();
1507 let seed = req.seed.unwrap_or_else(rand_seed);
1508
1509 let width = req.width as usize;
1510 let height = req.height as usize;
1511 let guidance = req.guidance;
1512
1513 tracing::info!(
1514 prompt = %req.prompt,
1515 seed, width, height,
1516 steps = req.steps,
1517 guidance,
1518 scheduler = %self.scheduler,
1519 "starting SD1.5 generation"
1520 );
1521
1522 let max_len = loaded.sd_config.clip.max_position_embeddings;
1524 let neg = req.negative_prompt.as_deref().unwrap_or("");
1525 let text_embeddings = self.encode_prompt(
1526 &loaded.clip,
1527 &loaded.tokenizer,
1528 &req.prompt,
1529 neg,
1530 max_len,
1531 &loaded.device,
1532 &loaded.clip_device,
1533 loaded.dtype,
1534 guidance,
1535 )?;
1536
1537 let sched = req.scheduler.unwrap_or(self.scheduler);
1539
1540 let (mut latents, start_step, inpaint_ctx) =
1541 if let Some(ref source_bytes) = req.source_image {
1542 let (latents, start_step, encoded, noise) = self.prepare_img2img_latents(
1543 &loaded.vae,
1544 source_bytes,
1545 req.width,
1546 req.height,
1547 req.strength,
1548 req.steps,
1549 sched,
1550 seed,
1551 &loaded.device,
1552 loaded.dtype,
1553 loaded.vae_dtype,
1554 )?;
1555 let inpaint_ctx = if let Some(ref mask_bytes) = req.mask_image {
1556 let mask = self.cached_mask(
1557 mask_bytes,
1558 height / 8,
1559 width / 8,
1560 &loaded.device,
1561 loaded.dtype,
1562 )?;
1563 Some(crate::img_utils::InpaintContext {
1564 original_latents: encoded,
1565 mask,
1566 noise,
1567 })
1568 } else {
1569 None
1570 };
1571 (latents, start_step, inpaint_ctx)
1572 } else {
1573 let latent_h = height / 8;
1574 let latent_w = width / 8;
1575 let init_scheduler = crate::scheduler::build_scheduler(
1576 sched,
1577 req.steps as usize,
1578 PredictionType::Epsilon,
1579 false,
1580 )?;
1581 let init_noise_sigma = init_scheduler.init_noise_sigma();
1582 drop(init_scheduler);
1583 let latents = (crate::engine::seeded_randn(
1584 seed,
1585 &[1, 4, latent_h, latent_w],
1586 &loaded.device,
1587 DType::F32,
1588 )? * init_noise_sigma)?;
1589 (latents.to_dtype(loaded.dtype)?, 0, None)
1590 };
1591
1592 let controlnet_ctx = self.load_controlnet(req, &loaded.device, loaded.dtype)?;
1594
1595 let unet = loaded
1597 .unet
1598 .as_ref()
1599 .ok_or_else(|| anyhow::anyhow!("UNet not loaded"))?;
1600 self.denoise_loop(
1601 unet,
1602 &text_embeddings,
1603 sched,
1604 &mut latents,
1605 guidance,
1606 resolve_cfg_plus(req),
1607 req.steps,
1608 start_step,
1609 inpaint_ctx.as_ref(),
1610 controlnet_ctx.as_ref(),
1611 )?;
1612
1613 drop(controlnet_ctx);
1615 drop(inpaint_ctx);
1616 let _ = loaded;
1618 let loaded = self.base.loaded.as_mut().unwrap();
1619 loaded.unet = None;
1620 loaded.device.synchronize()?;
1621 tracing::info!("UNet dropped to free VRAM for VAE decode");
1622 let _ = loaded;
1624 let loaded = self.base.loaded.as_ref().unwrap();
1625
1626 self.base.progress.stage_start("VAE decode");
1628 let vae_start = Instant::now();
1629
1630 let latents = (latents / VAE_SCALE)?;
1631 let latents_for_vae = latents.to_dtype(loaded.vae_dtype)?;
1632 let vae = &loaded.vae;
1633 let device_for_sync = loaded.device.clone();
1634 let img = crate::vae_tiling::decode_with_oom_fallback(
1635 &latents_for_vae,
1636 |t| vae.decode(t).map_err(Into::into),
1637 || {
1638 if let Err(e) = device_for_sync.synchronize() {
1639 tracing::warn!(
1640 "SD1.5 (parallel) device.synchronize() after VAE OOM failed: {e}"
1641 );
1642 }
1643 },
1644 )?;
1645
1646 let img = ((img / 2.)? + 0.5)?.clamp(0f32, 1f32)?;
1647 let img = (img * 255.)?.to_dtype(DType::U8)?;
1648 let img = img.squeeze(0)?;
1649
1650 self.base
1651 .progress
1652 .stage_done("VAE decode", vae_start.elapsed());
1653
1654 let output_metadata = build_output_metadata(req, seed, Some(sched));
1656 let image_bytes = encode_image(
1657 &img,
1658 req.resolved_output_format(),
1659 req.width,
1660 req.height,
1661 output_metadata.as_ref(),
1662 )?;
1663
1664 let generation_time_ms = start.elapsed().as_millis() as u64;
1665 tracing::info!(generation_time_ms, seed, "SD1.5 generation complete");
1666
1667 Ok(GenerateResponse {
1668 images: vec![ImageData {
1669 data: image_bytes,
1670 format: req.resolved_output_format(),
1671 width: req.width,
1672 height: req.height,
1673 index: 0,
1674 }],
1675 generation_time_ms,
1676 model: req.model.clone(),
1677 seed_used: seed,
1678 video: None,
1679 gpu: None,
1680 })
1681 }
1682}
1683
1684impl InferenceEngine for SD15Engine {
1685 fn generate(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1686 self.pending_placement = req.placement.clone();
1687 self.pending_loras = super::lora::effective_sd15_loras(req);
1688 let result = self.generate_inner(req);
1689 self.pending_placement = None;
1690 self.pending_loras.clear();
1691 result
1692 }
1693
1694 fn model_name(&self) -> &str {
1695 self.base.model_name()
1696 }
1697
1698 fn is_loaded(&self) -> bool {
1699 self.base.is_loaded()
1700 }
1701
1702 fn load(&mut self) -> Result<()> {
1703 SD15Engine::load(self)
1704 }
1705
1706 fn unload(&mut self) {
1707 self.base.unload();
1708 clear_cache(&self.prompt_cache);
1709 clear_cache(&self.source_latent_cache);
1710 clear_cache(&self.mask_cache);
1711 clear_cache(&self.control_tensor_cache);
1712 self.active_lora_fingerprint.clear();
1713 }
1714
1715 fn set_on_progress(&mut self, callback: ProgressCallback) {
1716 self.base.set_on_progress(callback);
1717 }
1718
1719 fn clear_on_progress(&mut self) {
1720 self.base.clear_on_progress();
1721 }
1722
1723 fn model_paths(&self) -> Option<&mold_core::ModelPaths> {
1724 Some(&self.base.paths)
1725 }
1726}
1727
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::InferenceEngine;
    use crate::shared_pool::SharedPool;
    use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
    use std::collections::HashMap;
    use std::path::PathBuf;
    use std::sync::{Arc, Mutex};
    use tokenizers::models::bpe::BPE;

    /// Returns `<temp dir>/{stem}-{pid}-{nanos}.{extension}`.
    ///
    /// The process id and current time make the name unique enough that
    /// tests running in parallel do not share files. Nothing is created on
    /// disk.
    fn unique_temp_path(stem: &str, extension: &str) -> PathBuf {
        let pid = std::process::id();
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        std::env::temp_dir().join(format!("{stem}-{pid}-{nanos}.{extension}"))
    }

    /// Builds a `ModelPaths` with only the four files SD1.5 uses filled in;
    /// every other field is empty / `None`.
    fn stub_model_paths(
        transformer: PathBuf,
        vae: PathBuf,
        clip_encoder: PathBuf,
        clip_tokenizer: PathBuf,
    ) -> ModelPaths {
        ModelPaths {
            transformer,
            transformer_shards: Vec::new(),
            vae,
            spatial_upscaler: None,
            temporal_upscaler: None,
            distilled_lora: None,
            t5_encoder: None,
            clip_encoder: Some(clip_encoder),
            t5_tokenizer: None,
            clip_tokenizer: Some(clip_tokenizer),
            clip_encoder_2: None,
            clip_tokenizer_2: None,
            text_encoder_files: Vec::new(),
            text_tokenizer: None,
            decoder: None,
        }
    }

    /// Writes a tiny safetensors file whose tensor names use the single-file
    /// checkpoint prefixes (`model.diffusion_model.`, `first_stage_model.`,
    /// `cond_stage_model.`). Each tensor is a single `f32` zero, so the file
    /// has the right key layout but is far from a complete model.
    fn synth_sd15_single_file(name: &str) -> PathBuf {
        let path = unique_temp_path(&format!("mold-sd15-from-sf-{name}"), "safetensors");

        let keys: &[&str] = &[
            "model.diffusion_model.input_blocks.0.0.weight",
            "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight",
            "first_stage_model.encoder.down.0.block.0.norm1.weight",
            "first_stage_model.quant_conv.weight",
            "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight",
            "cond_stage_model.transformer.text_model.final_layer_norm.weight",
        ];

        // Every tensor has identical contents, so all views share one buffer.
        let zero = 0.0f32.to_le_bytes();
        let tensors: HashMap<String, TensorView<'_>> = keys
            .iter()
            .map(|key| {
                (
                    (*key).to_string(),
                    TensorView::new(SafeDtype::F32, vec![1], &zero).unwrap(),
                )
            })
            .collect();
        serialize_to_file(&tensors, &None, &path).unwrap();
        path
    }

    /// The single-file constructor accepts a checkpoint with the expected key
    /// layout, remembers its path, and does not load any weights.
    #[test]
    fn from_single_file_constructs_for_synthetic_sd15_checkpoint() {
        let single_file = synth_sd15_single_file("ok");
        let tokenizer_path = std::env::temp_dir().join("mold-sd15-tok-stub.json");

        let engine = SD15Engine::from_single_file(
            "dreamshaper-8".to_string(),
            single_file.clone(),
            tokenizer_path,
            Scheduler::default(),
            LoadStrategy::Eager,
            0,
            None,
        )
        .expect("constructor must accept a valid SD1.5 single-file layout");

        assert_eq!(engine.model_name(), "dreamshaper-8");
        assert_eq!(
            engine.single_file_path.as_deref(),
            Some(single_file.as_path()),
            "single-file path must be stashed for the future load() branch",
        );
        assert!(
            !engine.is_loaded(),
            "constructor must not eagerly materialise model weights",
        );

        let _ = std::fs::remove_file(single_file);
    }

    /// The engine's tokenizer loader returns the very `Arc` the shared pool
    /// already holds for that path.
    #[test]
    fn sd15_loads_clip_tokenizer_through_shared_pool() {
        let dir = tempfile::tempdir().unwrap();
        let tokenizer_path = dir.path().join("clip-tokenizer.json");
        tokenizers::Tokenizer::new(BPE::default())
            .save(&tokenizer_path, false)
            .unwrap();
        let weights_path = dir.path().join("weights.safetensors");
        std::fs::write(&weights_path, b"stub").unwrap();

        let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
        let pooled = shared_pool
            .lock()
            .unwrap()
            .load_tokenizer(&tokenizer_path)
            .unwrap();

        let paths = stub_model_paths(
            weights_path.clone(),
            weights_path.clone(),
            weights_path,
            tokenizer_path.clone(),
        );
        let engine = SD15Engine::new(
            "sd15-test".to_string(),
            paths,
            Scheduler::default(),
            LoadStrategy::Eager,
            0,
            Some(shared_pool),
        );

        let loaded = engine.load_clip_tokenizer(&tokenizer_path).unwrap();

        assert!(Arc::ptr_eq(&pooled, &loaded));
    }

    /// The engine's VAE tensor loader returns the very `Arc` the shared pool
    /// already holds for the VAE file.
    #[test]
    fn sd15_loads_vae_tensors_through_shared_pool() {
        let dir = tempfile::tempdir().unwrap();
        let vae_path = dir.path().join("vae.safetensors");
        let weight = 1.0f32.to_le_bytes();
        let mut tensors = HashMap::new();
        tensors.insert(
            "encoder.conv_in.weight".to_string(),
            TensorView::new(SafeDtype::F32, vec![1], &weight).unwrap(),
        );
        serialize_to_file(&tensors, &None, &vae_path).unwrap();

        let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
        let pooled = shared_pool
            .lock()
            .unwrap()
            .load_safetensors_cpu_tensors(std::slice::from_ref(&vae_path))
            .unwrap()
            .unwrap();

        let paths = stub_model_paths(
            dir.path().join("unet.safetensors"),
            vae_path.clone(),
            dir.path().join("clip.safetensors"),
            dir.path().join("clip-tokenizer.json"),
        );
        let engine = SD15Engine::new(
            "sd15-test".to_string(),
            paths,
            Scheduler::default(),
            LoadStrategy::Eager,
            0,
            Some(shared_pool),
        );

        let loaded = engine.load_vae_cpu_tensors().unwrap().unwrap();

        assert!(Arc::ptr_eq(&pooled, &loaded));
    }

    /// `load()` on a single-file engine must fail inside the single-file load
    /// layer, because the synthetic checkpoint is incomplete.
    #[test]
    fn load_branches_to_single_file_path_and_invokes_candle_constructors() {
        let single_file = synth_sd15_single_file("load-branch");
        let tokenizer_stub = unique_temp_path("mold-sd15-tok-stub", "json");
        std::fs::write(&tokenizer_stub, b"").unwrap();

        let mut engine = SD15Engine::from_single_file(
            "dreamshaper-8".to_string(),
            single_file.clone(),
            tokenizer_stub.clone(),
            Scheduler::Ddim,
            LoadStrategy::Eager,
            0,
            None,
        )
        .expect("constructor");

        // `MOLD_DEVICE` is process-global. Put the previous value back before
        // any assertion can panic, so neither a failing expectation nor a
        // pre-existing setting leaks into other tests in this process.
        let previous_device = std::env::var_os("MOLD_DEVICE");
        std::env::set_var("MOLD_DEVICE", "cpu");
        let load_result = SD15Engine::load(&mut engine);
        match previous_device {
            Some(value) => std::env::set_var("MOLD_DEVICE", value),
            None => std::env::remove_var("MOLD_DEVICE"),
        }

        // The fixtures are no longer needed once `load` has returned; remove
        // them before asserting so a failure does not leave them behind.
        let _ = std::fs::remove_file(single_file);
        let _ = std::fs::remove_file(tokenizer_stub);

        let err = load_result.expect_err(
            "synthetic single-file checkpoint can't satisfy SD1.5's full tensor set; \
             dispatch should land in the candle constructor and fail there",
        );

        let msg = err.to_string();
        assert!(
            msg.contains("single-file") || msg.contains("rename rule"),
            "expected error from the single-file load layer, got: {msg}",
        );
    }

    /// Building CLIP-L from a single file must not fail with the diffusers
    /// loader's "cannot find tensor text_model" error.
    #[test]
    fn build_clip_single_file_dispatches_through_backend_not_diffusers_loader() {
        let single_file = synth_sd15_single_file("seq-clip-dispatch");
        let remap = SD15Engine::load_sd15_remap(&single_file).expect("remap");

        let result = SD15Engine::build_clip_single_file(
            &single_file,
            &remap,
            &stable_diffusion::clip::Config::v1_5(),
            &Device::Cpu,
        );

        let err = result.expect_err("synthetic CLIP-L is incomplete");
        let msg = err.to_string();
        assert!(
            !msg.contains("cannot find tensor text_model"),
            "expected failure from SingleFileBackend ('no rename rule for diffusers key …'), \
             not from candle's diffusers `from_mmaped_safetensors`. Sequential dispatch \
             is still routing through `build_clip_transformer`. Got: {msg}",
        );

        let _ = std::fs::remove_file(single_file);
    }

    /// Building the UNet from a single file must not fail with the diffusers
    /// loader's "cannot find tensor conv_in" error.
    #[test]
    fn build_unet_single_file_dispatches_through_backend_not_diffusers_loader() {
        let single_file = synth_sd15_single_file("seq-unet-dispatch");
        let remap = SD15Engine::load_sd15_remap(&single_file).expect("remap");

        let result = SD15Engine::build_unet_single_file(
            &single_file,
            &remap,
            &stable_diffusion::StableDiffusionConfig::v1_5(None, None, None),
            &Device::Cpu,
            DType::F32,
        );

        let err = result.expect_err("synthetic UNet is incomplete");
        let msg = err.to_string();
        assert!(
            !msg.contains("cannot find tensor conv_in"),
            "expected failure from SingleFileBackend, not diffusers loader. Got: {msg}",
        );

        let _ = std::fs::remove_file(single_file);
    }

    /// Placeholder: ignored by default and currently has no body.
    #[test]
    #[ignore]
    fn from_single_file_real_shape_load_smoke() {
    }

    /// The single-file constructor reports an error for a path that does not
    /// exist.
    #[test]
    fn from_single_file_rejects_missing_file() {
        let bogus = unique_temp_path("mold-sd15-from-sf-missing", "safetensors");

        let result = SD15Engine::from_single_file(
            "missing".to_string(),
            bogus,
            std::env::temp_dir().join("mold-sd15-tok-stub.json"),
            Scheduler::default(),
            LoadStrategy::Eager,
            0,
            None,
        );

        assert!(
            result.is_err(),
            "constructor must surface a missing-file error before deeper parsing",
        );
    }

    /// Guidance of exactly 1.0 does not enable CFG.
    #[test]
    fn test_cfg_disabled_at_guidance_1_0() {
        assert!(!cfg_active(1.0));
    }

    /// Guidance marginally below 1.0 does not enable CFG.
    #[test]
    fn test_cfg_disabled_just_below_1_0() {
        assert!(!cfg_active(1.0 - 1e-5));
    }

    /// Guidance of 1.5 enables CFG.
    #[test]
    fn test_cfg_enabled_at_guidance_1_5() {
        assert!(cfg_active(1.5));
    }

    /// Guidance of 7.5 enables CFG.
    #[test]
    fn test_cfg_enabled_at_guidance_7_5() {
        assert!(cfg_active(7.5));
    }

    /// Two cache keys that differ only in the negative prompt must not share
    /// a cache entry; an identical key must still hit.
    #[test]
    fn sd15_prompt_cache_distinguishes_negative_prompt_changes() {
        use crate::cache::{cfg_prompt_cache_key, store_cached_tensor};

        let cache: Mutex<LruCache<CfgPromptCacheKey, CachedTensor>> =
            Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY));
        let device = Device::Cpu;
        let dtype = DType::F32;
        let embeddings = candle_core::Tensor::zeros((1, 4), dtype, &device).unwrap();

        let key_a = cfg_prompt_cache_key("a cat", "blurry", 7.0);
        store_cached_tensor(&cache, key_a.clone(), &embeddings).unwrap();

        // Same prompt and guidance, different negative prompt: must miss.
        let key_b = cfg_prompt_cache_key("a cat", "low quality", 7.0);
        let restored = restore_cached_tensor(&cache, &key_b, &device, dtype).unwrap();
        assert!(
            restored.is_none(),
            "different negative prompt must miss the cache (silent-wrong-output bug)",
        );

        // The original key must still hit.
        let restored = restore_cached_tensor(&cache, &key_a, &device, dtype).unwrap();
        assert!(
            restored.is_some(),
            "identical (pos, neg, guidance) must hit",
        );
    }

    /// The LoRA fingerprint is equal for equal stacks and differs when a scale
    /// or the order of the stack changes.
    #[test]
    fn lora_stack_fingerprint_equality_drives_unet_drop() {
        let a = mold_core::LoraWeight {
            path: "/a.safetensors".to_string(),
            scale: 0.8,
        };
        let b = mold_core::LoraWeight {
            path: "/b.safetensors".to_string(),
            scale: 0.4,
        };
        let same_a = mold_core::LoraWeight {
            path: "/a.safetensors".to_string(),
            scale: 0.8,
        };

        // Equal contents produce equal fingerprints.
        assert_eq!(
            lora_stack_fingerprint(&[a.clone(), b.clone()]),
            lora_stack_fingerprint(&[same_a.clone(), b.clone()])
        );

        // A changed scale changes the fingerprint.
        let scaled = mold_core::LoraWeight {
            path: "/a.safetensors".to_string(),
            scale: 0.81,
        };
        assert_ne!(
            lora_stack_fingerprint(std::slice::from_ref(&a)),
            lora_stack_fingerprint(std::slice::from_ref(&scaled))
        );

        // Order is significant.
        assert_ne!(
            lora_stack_fingerprint(&[a.clone(), b.clone()]),
            lora_stack_fingerprint(&[b, a])
        );
    }
}