Skip to main content

mold_inference/sd15/
pipeline.rs

1use anyhow::{bail, Result};
2use candle_core::{DType, Device, Module, Tensor};
3use candle_transformers::models::stable_diffusion;
4use candle_transformers::models::stable_diffusion::schedulers::PredictionType;
5use mold_core::{GenerateRequest, GenerateResponse, ImageData, ModelPaths, Scheduler};
6use std::collections::HashMap;
7use std::path::{Path, PathBuf};
8use std::sync::{Arc, Mutex};
9use std::time::Instant;
10
11use crate::cache::{
12    cfg_prompt_cache_key, clear_cache, get_or_insert_cached_tensor, image_size_cache_key,
13    latent_size_cache_key, restore_cached_tensor, CachedTensor, CfgPromptCacheKey,
14    ImageSizeCacheKey, LatentSizeCacheKey, LruCache, DEFAULT_IMAGE_CACHE_CAPACITY,
15    DEFAULT_PROMPT_CACHE_CAPACITY,
16};
17use crate::cfg_plus_ddim::DdimAlphaSchedule;
18use crate::controlnet::ControlNetModel;
19use crate::device::{check_memory_budget, memory_status_string, preflight_memory_check};
20use crate::engine::{cfg_active, rand_seed, resolve_cfg_plus, InferenceEngine, LoadStrategy};
21use crate::engine_base::EngineBase;
22use crate::image::{build_output_metadata, encode_image};
23use crate::progress::{ProgressCallback, ProgressEvent};
24
/// Latent scaling factor for the SD1.5 VAE (the standard SD1.5 value).
const VAE_SCALE: f64 = 0.18215;
27
28/// ControlNet context: loaded model, preprocessed control tensor, and scale.
29struct ControlNetContext {
30    model: ControlNetModel,
31    control_tensor: Tensor,
32    scale: f64,
33}
34
35/// Loaded SD1.5 model components, ready for inference.
36struct LoadedSD15 {
37    /// None after being dropped for VAE decode VRAM; reloaded on next generate.
38    unet: Option<stable_diffusion::unet_2d::UNet2DConditionModel>,
39    vae: stable_diffusion::vae::AutoEncoderKL,
40    clip: stable_diffusion::clip::ClipTextTransformer,
41    tokenizer: Arc<tokenizers::Tokenizer>,
42    sd_config: stable_diffusion::StableDiffusionConfig,
43    device: Device,
44    /// Device the CLIP-L weights live on. May differ from `device` when the
45    /// Tier 1 `text_encoders` placement override pins encoders to CPU.
46    clip_device: Device,
47    dtype: DType,
48    /// Effective VAE dtype after `MOLD_VAE_DTYPE` resolution. May differ from
49    /// `dtype` when fp32 VAE decode is forced. Captured at load time.
50    vae_dtype: DType,
51}
52
53/// SD1.5 inference engine backed by candle's stable_diffusion module.
54///
55/// Simplified variant of SDXL: single CLIP-L encoder, smaller UNet, 512x512 default.
56pub struct SD15Engine {
57    base: EngineBase<LoadedSD15>,
58    scheduler: Scheduler,
59    shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
60    prompt_cache: Mutex<LruCache<CfgPromptCacheKey, CachedTensor>>,
61    source_latent_cache: Mutex<LruCache<ImageSizeCacheKey, CachedTensor>>,
62    mask_cache: Mutex<LruCache<LatentSizeCacheKey, CachedTensor>>,
63    control_tensor_cache: Mutex<LruCache<ImageSizeCacheKey, CachedTensor>>,
64    pending_placement: Option<mold_core::types::DevicePlacement>,
65    /// `Some(path)` when the engine was built from a Civitai single-file
66    /// checkpoint via `from_single_file()`. `load()` branches on this to
67    /// wire a custom `SingleFileBackend` (translating diffusers `vb.get`
68    /// calls through `Sd15Remap` into mmap'd A1111 reads) instead of
69    /// calling candle's per-component `build_*` helpers. `None` for
70    /// diffusers-layout (HF) checkpoints.
71    pub(crate) single_file_path: Option<PathBuf>,
72    /// Per-request LoRA list, set from `req.lora` / `req.loras` in
73    /// `generate()` and cleared after. Read by `build_unet_for_strategy` to
74    /// decide whether to wrap the UNet's `VarBuilder` with `super::lora`'s
75    /// `wrap_backend_with_lora`. Empty in the no-LoRA hot path. Mirrors the
76    /// SDXL field of the same name.
77    pending_loras: Vec<mold_core::LoraWeight>,
78    /// Fingerprint of the LoRA stack currently merged into the loaded UNet
79    /// (eager mode). When `effective_sd15_loras(req)` produces a different
80    /// fingerprint, `generate_inner` drops the UNet so the next
81    /// `reload_unet_if_needed` call rebuilds it with the new wrapper.
82    /// Empty when no LoRA is active.
83    active_lora_fingerprint: Vec<(String, u64)>,
84}
85
86/// Compute a stable fingerprint for a LoRA stack: ordered list of
87/// `(path, scale_bits)`. Two stacks are equal iff they merge to the same
88/// transformer, so a comparison drives UNet reload on any change (swap,
89/// scale, add, remove, reorder).
90fn lora_stack_fingerprint(loras: &[mold_core::LoraWeight]) -> Vec<(String, u64)> {
91    loras
92        .iter()
93        .map(|w| (w.path.clone(), w.scale.to_bits()))
94        .collect()
95}
96
97impl SD15Engine {
98    pub fn new(
99        model_name: String,
100        paths: ModelPaths,
101        scheduler: Scheduler,
102        load_strategy: LoadStrategy,
103        gpu_ordinal: usize,
104        shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
105    ) -> Self {
106        Self {
107            base: EngineBase::new(model_name, paths, load_strategy, gpu_ordinal),
108            scheduler,
109            shared_pool,
110            prompt_cache: Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY)),
111            source_latent_cache: Mutex::new(LruCache::new(DEFAULT_IMAGE_CACHE_CAPACITY)),
112            mask_cache: Mutex::new(LruCache::new(DEFAULT_IMAGE_CACHE_CAPACITY)),
113            control_tensor_cache: Mutex::new(LruCache::new(DEFAULT_IMAGE_CACHE_CAPACITY)),
114            pending_placement: None,
115            single_file_path: None,
116            pending_loras: Vec::new(),
117            active_lora_fingerprint: Vec::new(),
118        }
119    }
120
121    /// Construct a SD1.5 engine from a Civitai single-file `.safetensors`
122    /// checkpoint.
123    ///
124    /// Header-parses the file via `loader::single_file::load(_,
125    /// Family::Sd15)`, validates that the SD1.5 A1111 → diffusers rename
126    /// rules cover its UNet / VAE / CLIP-L tensor keys via
127    /// `loader::sd15_keys::build_sd15_remap(_)`, and stashes the path in
128    /// `single_file_path`. The actual UNet / VAE / CLIP-L materialisation
129    /// (with a custom `SimpleBackend` that translates each diffusers
130    /// `vb.get(name)` into the corresponding A1111 key inside the mmap'd
131    /// single file) lands in a downstream phase — this constructor does
132    /// **not** build the model.
133    ///
134    /// `clip_tokenizer` is the path to a companion-pulled CLIP-L
135    /// tokenizer (phase 2.7); the tokenizer never lives inside the
136    /// single-file checkpoint.
137    pub fn from_single_file(
138        model_name: String,
139        single_file_path: PathBuf,
140        clip_tokenizer: PathBuf,
141        scheduler: Scheduler,
142        load_strategy: LoadStrategy,
143        gpu_ordinal: usize,
144        shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
145    ) -> Result<Self> {
146        if !single_file_path.exists() {
147            bail!(
148                "single-file checkpoint not found: {}",
149                single_file_path.display()
150            );
151        }
152
153        // Header-parse and validate the SD1.5 layout. `single_file::load`
154        // is header-only — no tensor data is mmap'd here.
155        let bundle = crate::loader::single_file::load(
156            &single_file_path,
157            mold_catalog::families::Family::Sd15,
158        )?;
159
160        // Validate that every diffusers key the future SimpleBackend will
161        // request resolves to a real A1111 key inside the file. Catches
162        // malformed checkpoints at construction time rather than deep in
163        // the eventual `load()` call.
164        let _remap = crate::loader::sd15_keys::build_sd15_remap(&bundle)?;
165
166        // `ModelPaths` is a diffusers-layout view; for single-file the
167        // transformer / vae / clip_encoder all materialise from the same
168        // checkpoint at load time. The real branch lives on the new
169        // `single_file_path` field — future `load()` consults it before
170        // falling through to the diffusers `build_*` helpers.
171        let paths = ModelPaths {
172            transformer: single_file_path.clone(),
173            transformer_shards: Vec::new(),
174            vae: single_file_path.clone(),
175            spatial_upscaler: None,
176            temporal_upscaler: None,
177            distilled_lora: None,
178            t5_encoder: None,
179            clip_encoder: Some(single_file_path.clone()),
180            t5_tokenizer: None,
181            clip_tokenizer: Some(clip_tokenizer),
182            clip_encoder_2: None,
183            clip_tokenizer_2: None,
184            text_encoder_files: Vec::new(),
185            text_tokenizer: None,
186            decoder: None,
187        };
188
189        Ok(Self {
190            base: EngineBase::new(model_name, paths, load_strategy, gpu_ordinal),
191            scheduler,
192            shared_pool,
193            prompt_cache: Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY)),
194            source_latent_cache: Mutex::new(LruCache::new(DEFAULT_IMAGE_CACHE_CAPACITY)),
195            mask_cache: Mutex::new(LruCache::new(DEFAULT_IMAGE_CACHE_CAPACITY)),
196            control_tensor_cache: Mutex::new(LruCache::new(DEFAULT_IMAGE_CACHE_CAPACITY)),
197            pending_placement: None,
198            single_file_path: Some(single_file_path),
199            pending_loras: Vec::new(),
200            active_lora_fingerprint: Vec::new(),
201        })
202    }
203
204    /// Validate and return required SD1.5 paths (CLIP-L encoder + tokenizer).
205    fn validate_paths(&self) -> Result<(std::path::PathBuf, std::path::PathBuf)> {
206        let clip_encoder = self
207            .base
208            .paths
209            .clip_encoder
210            .as_ref()
211            .ok_or_else(|| anyhow::anyhow!("CLIP-L encoder path required for SD1.5 models"))?
212            .clone();
213        let clip_tokenizer = self
214            .base
215            .paths
216            .clip_tokenizer
217            .as_ref()
218            .ok_or_else(|| anyhow::anyhow!("CLIP-L tokenizer path required for SD1.5 models"))?
219            .clone();
220
221        for (label, path) in [
222            ("transformer (UNet)", &self.base.paths.transformer),
223            ("vae", &self.base.paths.vae),
224            ("clip_encoder (CLIP-L)", &clip_encoder),
225            ("clip_tokenizer (CLIP-L)", &clip_tokenizer),
226        ] {
227            if !path.exists() {
228                bail!("{label} file not found: {}", path.display());
229            }
230        }
231
232        Ok((clip_encoder, clip_tokenizer))
233    }
234
235    fn load_clip_tokenizer(
236        &self,
237        clip_tokenizer: &std::path::Path,
238    ) -> Result<Arc<tokenizers::Tokenizer>> {
239        if let Some(ref pool) = self.shared_pool {
240            return pool.lock().unwrap().load_tokenizer(clip_tokenizer);
241        }
242        Ok(Arc::new(
243            tokenizers::Tokenizer::from_file(clip_tokenizer)
244                .map_err(|e| anyhow::anyhow!("failed to load CLIP-L tokenizer: {e}"))?,
245        ))
246    }
247
248    /// Create the SD1.5 config.
249    fn sd_config(&self) -> stable_diffusion::StableDiffusionConfig {
250        stable_diffusion::StableDiffusionConfig::v1_5(None, None, None)
251    }
252
253    /// Reload UNet if it was dropped after VAE decode.
254    fn reload_unet_if_needed(&mut self) -> Result<()> {
255        let needs_reload = self
256            .base
257            .loaded
258            .as_ref()
259            .map(|l| l.unet.is_none())
260            .unwrap_or(false);
261
262        if needs_reload {
263            let sd_config = self.sd_config();
264            let loaded = self.base.loaded.as_ref().unwrap();
265            let device = loaded.device.clone();
266            let dtype = loaded.dtype;
267            let _ = loaded;
268
269            self.base.progress.stage_start("Reloading UNet (GPU)");
270            let reload_start = Instant::now();
271            let unet = self.build_unet_for_strategy(&sd_config, &device, dtype)?;
272            self.base.loaded.as_mut().unwrap().unet = Some(unet);
273            self.base
274                .progress
275                .stage_done("Reloading UNet (GPU)", reload_start.elapsed());
276        }
277        Ok(())
278    }
279
280    /// UNet load that branches on `single_file_path` — Civitai single-file
281    /// checkpoints (A1111 naming) get the `SingleFileBackend` dispatch,
282    /// diffusers-layout falls through to candle's `build_unet`. Used by
283    /// every UNet load site (eager `load`, sequential `generate_sequential`,
284    /// and `reload_unet_if_needed`) so the branch logic lives in exactly
285    /// one place.
286    ///
287    /// When `self.pending_loras` is non-empty, the underlying tensor source
288    /// (mmap'd diffusers safetensors or `SingleFileBackend`) is wrapped in
289    /// `super::lora::wrap_backend_with_lora` so the UNet construction loads
290    /// `W' = W + scale·(B @ A)` for every LoRA-targeted layer. The wrapper
291    /// is transparent to `UNet2DConditionModel::new`.
292    fn build_unet_for_strategy(
293        &self,
294        sd_config: &stable_diffusion::StableDiffusionConfig,
295        device: &Device,
296        dtype: DType,
297    ) -> Result<stable_diffusion::unet_2d::UNet2DConditionModel> {
298        let has_lora = !self.pending_loras.is_empty();
299        if let Some(single_file) = self.single_file_path.as_ref() {
300            let remap = Self::load_sd15_remap(single_file)?;
301            if has_lora {
302                self.build_unet_single_file_with_lora(single_file, &remap, sd_config, device, dtype)
303            } else {
304                Self::build_unet_single_file(single_file, &remap, sd_config, device, dtype)
305            }
306        } else if has_lora {
307            self.build_unet_diffusers_with_lora(sd_config, device, dtype)
308        } else {
309            Ok(sd_config.build_unet(&self.base.paths.transformer, device, 4, false, dtype)?)
310        }
311    }
312
313    /// Build the UNet from a diffusers-layout single safetensors file with
314    /// LoRA wrappers active. Mirrors candle's `build_unet`: open the mmap,
315    /// wrap it in a `SimpleBackend`, layer `Sd15LoraBackend` on top, and feed
316    /// the resulting `VarBuilder` to `UNet2DConditionModel::new`.
317    fn build_unet_diffusers_with_lora(
318        &self,
319        sd_config: &stable_diffusion::StableDiffusionConfig,
320        device: &Device,
321        dtype: DType,
322    ) -> Result<stable_diffusion::unet_2d::UNet2DConditionModel> {
323        use candle_core::safetensors::MmapedSafetensors;
324        use candle_nn::VarBuilder;
325
326        let st = unsafe { MmapedSafetensors::multi(&[&self.base.paths.transformer])? };
327
328        struct MmapBackend {
329            st: MmapedSafetensors,
330        }
331        impl candle_nn::var_builder::SimpleBackend for MmapBackend {
332            fn get(
333                &self,
334                _s: candle_core::Shape,
335                name: &str,
336                _h: candle_nn::Init,
337                dtype: DType,
338                dev: &Device,
339            ) -> candle_core::Result<Tensor> {
340                let t = self.st.load(name, dev)?;
341                if t.dtype() != dtype {
342                    t.to_dtype(dtype)
343                } else {
344                    Ok(t)
345                }
346            }
347            fn get_unchecked(
348                &self,
349                name: &str,
350                dtype: DType,
351                dev: &Device,
352            ) -> candle_core::Result<Tensor> {
353                let t = self.st.load(name, dev)?;
354                if t.dtype() != dtype {
355                    t.to_dtype(dtype)
356                } else {
357                    Ok(t)
358                }
359            }
360            fn contains_tensor(&self, name: &str) -> bool {
361                self.st.get(name).is_ok()
362            }
363        }
364        let inner: Box<dyn candle_nn::var_builder::SimpleBackend> = Box::new(MmapBackend { st });
365        let wrapped = self.wrap_with_loras(inner)?;
366        let vb = VarBuilder::from_backend(wrapped, dtype, device.clone());
367        Ok(stable_diffusion::unet_2d::UNet2DConditionModel::new(
368            vb,
369            4,
370            4,
371            false,
372            sd_config.unet().clone(),
373        )?)
374    }
375
376    /// Same as `build_unet_single_file` but with the LoRA wrapper layered on
377    /// top of `SingleFileBackend`. The SD1.5 `SingleFileBackend` translates
378    /// diffusers candle keys onto the mmap'd A1111 checkpoint; the LoRA
379    /// wrapper then intercepts the merged tensor and adds the per-layer
380    /// delta before it lands on the UNet constructor.
381    fn build_unet_single_file_with_lora(
382        &self,
383        single_file: &std::path::Path,
384        remap: &crate::loader::Sd15Remap,
385        sd_config: &stable_diffusion::StableDiffusionConfig,
386        device: &Device,
387        dtype: DType,
388    ) -> Result<stable_diffusion::unet_2d::UNet2DConditionModel> {
389        use crate::loader::SingleFileBackend;
390        use candle_nn::VarBuilder;
391
392        let backend = SingleFileBackend::from_sd15_unet(single_file, remap)?;
393        let inner: Box<dyn candle_nn::var_builder::SimpleBackend> = Box::new(backend);
394        let wrapped = self.wrap_with_loras(inner)?;
395        let vb = VarBuilder::from_backend(wrapped, dtype, device.clone());
396        Ok(stable_diffusion::unet_2d::UNet2DConditionModel::new(
397            vb,
398            4,
399            4,
400            false,
401            sd_config.unet().clone(),
402        )?)
403    }
404
405    /// Wrap an `inner` SimpleBackend with the SD1.5 LoRA backend. Resolves
406    /// the `pending_loras` list into `Sd15LoraSpec`s (parsed-LoRA cache hits
407    /// keep adapter parsing cheap across requests).
408    fn wrap_with_loras(
409        &self,
410        inner: Box<dyn candle_nn::var_builder::SimpleBackend>,
411    ) -> Result<Box<dyn candle_nn::var_builder::SimpleBackend>> {
412        let adapters = super::lora::load_lora_adapters(&self.pending_loras, &self.base.progress)?;
413        let specs: Vec<super::lora::Sd15LoraSpec<'_>> = adapters
414            .iter()
415            .zip(self.pending_loras.iter())
416            .map(|(adapter, w)| super::lora::Sd15LoraSpec {
417                adapter: adapter.as_ref(),
418                scale: w.scale,
419                path_hash: super::lora::lora_path_hash(&w.path),
420            })
421            .collect();
422        super::lora::wrap_backend_with_lora(inner, &specs, &self.base.progress, None)
423    }
424
425    /// VAE load with the same single-file vs diffusers branch as
426    /// `build_unet_for_strategy`. Sequential SD1.5 loads the VAE twice
427    /// (once for img2img encode, once for post-denoise decode), so this
428    /// helper keeps the branch logic in a single place.
429    fn build_vae_for_strategy(
430        &self,
431        sd_config: &stable_diffusion::StableDiffusionConfig,
432        device: &Device,
433        dtype: DType,
434    ) -> Result<stable_diffusion::vae::AutoEncoderKL> {
435        if let Some(single_file) = self.single_file_path.as_ref() {
436            let remap = Self::load_sd15_remap(single_file)?;
437            Self::build_vae_single_file(single_file, &remap, sd_config, device, dtype)
438        } else {
439            self.build_vae_diffusers(sd_config, device, dtype)
440        }
441    }
442
443    #[cfg(test)]
444    fn load_vae_cpu_tensors(&self) -> Result<Option<Arc<HashMap<String, Tensor>>>> {
445        self.load_vae_cpu_tensors_for_path(&self.base.paths.vae)
446    }
447
448    fn load_vae_cpu_tensors_for_path(
449        &self,
450        vae_path: &Path,
451    ) -> Result<Option<Arc<HashMap<String, Tensor>>>> {
452        let Some(shared_pool) = &self.shared_pool else {
453            return Ok(None);
454        };
455        shared_pool
456            .lock()
457            .unwrap()
458            .load_safetensors_cpu_tensors(std::slice::from_ref(&vae_path))
459    }
460
461    fn load_vae_var_builder<'a>(
462        &self,
463        vae_path: &Path,
464        dtype: DType,
465        device: &Device,
466        component: &str,
467    ) -> Result<candle_nn::VarBuilder<'a>> {
468        if let Some(tensors) = self.load_vae_cpu_tensors_for_path(vae_path)? {
469            return Ok(crate::encoders::park::varbuilder_from_parked(
470                tensors.as_ref(),
471                dtype,
472                device,
473            ));
474        }
475
476        crate::weight_loader::load_safetensors_with_progress(
477            std::slice::from_ref(&vae_path),
478            dtype,
479            device,
480            component,
481            &self.base.progress,
482        )
483    }
484
485    fn build_vae_diffusers(
486        &self,
487        sd_config: &stable_diffusion::StableDiffusionConfig,
488        device: &Device,
489        dtype: DType,
490    ) -> Result<stable_diffusion::vae::AutoEncoderKL> {
491        let vb = self.load_vae_var_builder(&self.base.paths.vae, dtype, device, "VAE")?;
492        Ok(stable_diffusion::vae::AutoEncoderKL::new(
493            vb,
494            3,
495            3,
496            sd_config.autoencoder().clone(),
497        )?)
498    }
499
500    /// Header-parse the single-file checkpoint and build the SD1.5
501    /// diffusers→A1111 remap. Cheap (no tensor data is mmap'd) — sequential
502    /// reload calls this each time a component is reloaded after dropping.
503    fn load_sd15_remap(single_file: &std::path::Path) -> Result<crate::loader::Sd15Remap> {
504        use crate::loader::{build_sd15_remap, single_file as single_file_loader};
505        use mold_catalog::families::Family;
506        let bundle = single_file_loader::load(single_file, Family::Sd15)
507            .map_err(|e| anyhow::anyhow!("partition single-file SD1.5 checkpoint: {e}"))?;
508        build_sd15_remap(&bundle)
509            .map_err(|e| anyhow::anyhow!("build SD1.5 diffusers→A1111 remap: {e}"))
510    }
511
512    /// Build a UNet from a Civitai single-file checkpoint via a UNet-scoped
513    /// `SingleFileBackend`. Used by both eager `load_components_single_file`
514    /// and sequential `generate_sequential`. SD1.5 component keyspaces are
515    /// disjoint (no collision risk), but per-scope factories keep parity
516    /// with SDXL's collision-avoidance pattern.
517    fn build_unet_single_file(
518        single_file: &std::path::Path,
519        remap: &crate::loader::Sd15Remap,
520        sd_config: &stable_diffusion::StableDiffusionConfig,
521        device: &Device,
522        dtype: DType,
523    ) -> Result<stable_diffusion::unet_2d::UNet2DConditionModel> {
524        use crate::loader::SingleFileBackend;
525        use candle_nn::VarBuilder;
526        let backend = SingleFileBackend::from_sd15_unet(single_file, remap)?;
527        let vb = VarBuilder::from_backend(Box::new(backend), dtype, device.clone());
528        Ok(stable_diffusion::unet_2d::UNet2DConditionModel::new(
529            vb,
530            4,
531            4,
532            false,
533            sd_config.unet().clone(),
534        )?)
535    }
536
537    /// Build a VAE from a Civitai single-file checkpoint via a VAE-scoped
538    /// `SingleFileBackend`.
539    fn build_vae_single_file(
540        single_file: &std::path::Path,
541        remap: &crate::loader::Sd15Remap,
542        sd_config: &stable_diffusion::StableDiffusionConfig,
543        device: &Device,
544        dtype: DType,
545    ) -> Result<stable_diffusion::vae::AutoEncoderKL> {
546        use crate::loader::SingleFileBackend;
547        use candle_nn::VarBuilder;
548        let backend = SingleFileBackend::from_sd15_vae(single_file, remap)?;
549        let vb = VarBuilder::from_backend(Box::new(backend), dtype, device.clone());
550        Ok(stable_diffusion::vae::AutoEncoderKL::new(
551            vb,
552            3,
553            3,
554            sd_config.autoencoder().clone(),
555        )?)
556    }
557
558    /// Build CLIP-L from a Civitai single-file checkpoint via a CLIP-L-scoped
559    /// `SingleFileBackend`. SD1.5 has only one CLIP-L (no CLIP-G).
560    fn build_clip_single_file(
561        single_file: &std::path::Path,
562        remap: &crate::loader::Sd15Remap,
563        clip_config: &stable_diffusion::clip::Config,
564        clip_device: &Device,
565    ) -> Result<stable_diffusion::clip::ClipTextTransformer> {
566        use crate::loader::SingleFileBackend;
567        use candle_nn::VarBuilder;
568        let backend = SingleFileBackend::from_sd15_clip_l(single_file, remap)?;
569        let vb = VarBuilder::from_backend(Box::new(backend), DType::F32, clip_device.clone());
570        Ok(stable_diffusion::clip::ClipTextTransformer::new(
571            vb,
572            clip_config,
573        )?)
574    }
575
576    /// Load all SD1.5 model components (Eager mode).
577    ///
578    /// Two materialisation paths share this entry point:
579    ///
580    /// - **Diffusers-layout** (`single_file_path == None`) — calls candle's
581    ///   `build_unet` / `build_vae` / `build_clip_transformer` helpers,
582    ///   each of which mmaps a separate component file through the
583    ///   default `VarBuilder::from_mmaped_safetensors`.
584    /// - **Single-file** (`single_file_path == Some(path)`) — header-parses
585    ///   the checkpoint, runs `build_sd15_remap` to produce a
586    ///   diffusers→A1111 lookup, and feeds candle's per-component
587    ///   constructors a `VarBuilder::from_backend(SingleFileBackend)` that
588    ///   translates each `vb.get(diffusers_key)` into an mmap'd A1111 read.
589    ///   Three backends share the underlying mmap (one per component VB);
590    ///   peak memory is identical to the diffusers path.
591    ///
592    /// On error, `self.base.loaded` remains `None` — all components are
593    /// assembled into local variables and only stored on success, so
594    /// partial loads cannot leave the engine in an inconsistent state.
595    pub fn load(&mut self) -> Result<()> {
596        if self.base.loaded.is_some() {
597            return Ok(());
598        }
599
600        // Sequential mode defers loading to generate_sequential()
601        if self.base.load_strategy == LoadStrategy::Sequential {
602            return Ok(());
603        }
604
605        let (clip_encoder, clip_tokenizer) = self.validate_paths()?;
606
607        tracing::info!(model = %self.base.model_name, "loading SD1.5 model components...");
608
609        let device = crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)?;
610        let dtype = if crate::device::is_gpu(&device) {
611            DType::F16
612        } else {
613            DType::F32
614        };
615
616        let sd_config = self.sd_config();
617
618        // Tier 1: honor `placement.text_encoders` override (group knob).
619        let tier1 = self
620            .pending_placement
621            .as_ref()
622            .map(|p| p.text_encoders)
623            .unwrap_or_default();
624        let clip_device = crate::device::resolve_device(Some(tier1), || Ok(device.clone()))?;
625
626        // Resolve VAE precision once at load — see LoadedSD15::vae_dtype.
627        let vae_dtype = crate::device::resolve_vae_dtype(dtype);
628        let (unet, vae, clip) = if let Some(single_file) = self.single_file_path.clone() {
629            self.load_components_single_file(
630                &single_file,
631                &sd_config,
632                &device,
633                &clip_device,
634                dtype,
635                vae_dtype,
636            )?
637        } else {
638            self.load_components_diffusers(
639                &clip_encoder,
640                &sd_config,
641                &device,
642                &clip_device,
643                dtype,
644                vae_dtype,
645            )?
646        };
647
648        let tokenizer = self.load_clip_tokenizer(&clip_tokenizer)?;
649
650        self.base.loaded = Some(LoadedSD15 {
651            unet: Some(unet),
652            vae,
653            clip,
654            tokenizer,
655            sd_config,
656            device,
657            clip_device,
658            vae_dtype,
659            dtype,
660        });
661
662        tracing::info!(model = %self.base.model_name, "all SD1.5 components loaded successfully");
663        Ok(())
664    }
665
666    /// Diffusers-layout component loader (existing pre-2.6 path).
667    #[allow(clippy::too_many_arguments)]
668    fn load_components_diffusers(
669        &mut self,
670        clip_encoder: &std::path::Path,
671        sd_config: &stable_diffusion::StableDiffusionConfig,
672        device: &Device,
673        clip_device: &Device,
674        dtype: DType,
675        vae_dtype: DType,
676    ) -> Result<(
677        stable_diffusion::unet_2d::UNet2DConditionModel,
678        stable_diffusion::vae::AutoEncoderKL,
679        stable_diffusion::clip::ClipTextTransformer,
680    )> {
681        self.base.progress.stage_start("Loading UNet (GPU)");
682        let unet_start = Instant::now();
683        let unet = sd_config.build_unet(
684            &self.base.paths.transformer,
685            device,
686            4,     // in_channels
687            false, // use_flash_attn
688            dtype,
689        )?;
690        self.base
691            .progress
692            .stage_done("Loading UNet (GPU)", unet_start.elapsed());
693
694        self.base.progress.stage_start("Loading VAE (GPU)");
695        let vae_start = Instant::now();
696        let vae = self.build_vae_diffusers(sd_config, device, vae_dtype)?;
697        self.base
698            .progress
699            .stage_done("Loading VAE (GPU)", vae_start.elapsed());
700
701        self.base.progress.stage_start("Loading CLIP-L encoder");
702        let clip_start = Instant::now();
703        let clip = stable_diffusion::build_clip_transformer(
704            &sd_config.clip,
705            clip_encoder,
706            clip_device,
707            DType::F32,
708        )?;
709        self.base
710            .progress
711            .stage_done("Loading CLIP-L encoder", clip_start.elapsed());
712
713        Ok((unet, vae, clip))
714    }
715
716    /// Single-file (Civitai) component loader (phase 2.6 + 2.8.5).
717    ///
718    /// Header-parses the checkpoint, builds the diffusers→A1111 remap, wraps
719    /// it in a `SingleFileBackend` (which translates diffusers `vb.get(key)`
720    /// calls into reads from the mmap'd source tensors), then hands three
721    /// `VarBuilder::from_backend(SingleFileBackend)` instances to candle's
722    /// per-component constructors. UNet, VAE, and CLIP-L all read from the
723    /// same single-file mmap — no on-disk shard files needed.
724    ///
725    /// Reaches into `sd_config.unet()` / `.autoencoder()` accessors exposed
726    /// by candle-transformers-mold 0.9.12 (utensils/candle PR #1).
727    #[allow(clippy::too_many_arguments)]
728    fn load_components_single_file(
729        &mut self,
730        single_file: &std::path::Path,
731        sd_config: &stable_diffusion::StableDiffusionConfig,
732        device: &Device,
733        clip_device: &Device,
734        dtype: DType,
735        vae_dtype: DType,
736    ) -> Result<(
737        stable_diffusion::unet_2d::UNet2DConditionModel,
738        stable_diffusion::vae::AutoEncoderKL,
739        stable_diffusion::clip::ClipTextTransformer,
740    )> {
741        let remap = Self::load_sd15_remap(single_file)?;
742
743        self.base.progress.stage_start("Loading UNet (single-file)");
744        let unet_start = Instant::now();
745        let unet = Self::build_unet_single_file(single_file, &remap, sd_config, device, dtype)?;
746        self.base
747            .progress
748            .stage_done("Loading UNet (single-file)", unet_start.elapsed());
749
750        self.base.progress.stage_start("Loading VAE (single-file)");
751        let vae_start = Instant::now();
752        let vae = Self::build_vae_single_file(single_file, &remap, sd_config, device, vae_dtype)?;
753        self.base
754            .progress
755            .stage_done("Loading VAE (single-file)", vae_start.elapsed());
756
757        self.base
758            .progress
759            .stage_start("Loading CLIP-L (single-file)");
760        let clip_start = Instant::now();
761        let clip = Self::build_clip_single_file(single_file, &remap, &sd_config.clip, clip_device)?;
762        self.base
763            .progress
764            .stage_done("Loading CLIP-L (single-file)", clip_start.elapsed());
765
766        Ok((unet, vae, clip))
767    }
768
769    /// Tokenize a prompt for the CLIP encoder, padding/truncating to max_len tokens.
770    fn tokenize(
771        tokenizer: &tokenizers::Tokenizer,
772        prompt: &str,
773        max_len: usize,
774        device: &Device,
775    ) -> Result<Tensor> {
776        let encoding = tokenizer
777            .encode(prompt, true)
778            .map_err(|e| anyhow::anyhow!("tokenization failed: {e}"))?;
779        let mut ids = encoding.get_ids().to_vec();
780        ids.truncate(max_len);
781        // Pad with 0s (EOS/PAD token for CLIP)
782        while ids.len() < max_len {
783            ids.push(0);
784        }
785        let ids = ids.into_iter().map(|i| i as i64).collect::<Vec<_>>();
786        Ok(Tensor::new(ids, device)?.unsqueeze(0)?)
787    }
788
789    /// Run the denoising loop (shared between eager and sequential).
790    ///
791    /// `start_step` allows starting from a later timestep for img2img (0 = full txt2img).
792    #[allow(clippy::too_many_arguments)]
793    fn denoise_loop(
794        &self,
795        unet: &stable_diffusion::unet_2d::UNet2DConditionModel,
796        text_embeddings: &Tensor,
797        sched: Scheduler,
798        latents: &mut Tensor,
799        guidance: f64,
800        cfg_plus: bool,
801        steps: u32,
802        start_step: usize,
803        inpaint_ctx: Option<&crate::img_utils::InpaintContext>,
804        controlnet_ctx: Option<&ControlNetContext>,
805    ) -> Result<()> {
806        let use_cfg = cfg_active(guidance);
807        let mut scheduler = crate::scheduler::build_scheduler(
808            sched,
809            steps as usize,
810            PredictionType::Epsilon,
811            false,
812        )?;
813        let timesteps = scheduler.timesteps().to_vec();
814        let active_timesteps = &timesteps[start_step..];
815
816        // CFG++ requires the doubled `[uncond, cond]` forward AND a DDIM
817        // scheduler — see the SDXL counterpart for the full rationale.
818        let cfg_plus_schedule = if cfg_plus && use_cfg && matches!(sched, Scheduler::Ddim) {
819            Some(DdimAlphaSchedule::from_default(steps as usize))
820        } else {
821            if cfg_plus && !use_cfg {
822                tracing::warn!(
823                    guidance,
824                    "cfg_plus requested but cfg_scale ≈ 1.0 — falling back to standard step (no uncond available)"
825                );
826            } else if cfg_plus {
827                tracing::warn!(
828                    scheduler = ?sched,
829                    "cfg_plus requested but only DDIM is supported on SDXL/SD1.5 — falling back to standard step. Re-run with `--scheduler ddim` to enable CFG++."
830                );
831            }
832            None
833        };
834
835        let denoise_label = format!("Denoising ({} steps)", active_timesteps.len());
836        self.base.progress.stage_start(&denoise_label);
837        let denoise_start = Instant::now();
838
839        for (step_idx, &t) in active_timesteps.iter().enumerate() {
840            let step_start = std::time::Instant::now();
841            let latent_input = if use_cfg {
842                Tensor::cat(&[&*latents, &*latents], 0)?
843            } else {
844                latents.clone()
845            };
846
847            let latent_input = scheduler.scale_model_input(latent_input, t)?;
848
849            let noise_pred = if let Some(cn_ctx) = controlnet_ctx {
850                let (down_residuals, mid_residual) = cn_ctx.model.forward(
851                    &latent_input,
852                    t as f64,
853                    text_embeddings,
854                    &cn_ctx.control_tensor,
855                    cn_ctx.scale,
856                )?;
857                unet.forward_with_additional_residuals(
858                    &latent_input,
859                    t as f64,
860                    text_embeddings,
861                    Some(&down_residuals),
862                    Some(&mid_residual),
863                )?
864            } else {
865                unet.forward(&latent_input, t as f64, text_embeddings)?
866            };
867
868            // Hold onto the raw uncond row when CFG++ is active so we can use
869            // it as the renoise direction below.
870            let (noise_pred_blended, noise_pred_uncond_opt) = if use_cfg {
871                let chunks = noise_pred.chunk(2, 0)?;
872                let noise_pred_uncond = chunks[0].clone();
873                let noise_pred_cond = &chunks[1];
874                let blended =
875                    (&noise_pred_uncond + ((noise_pred_cond - &noise_pred_uncond)? * guidance)?)?;
876                (blended, Some(noise_pred_uncond))
877            } else {
878                (noise_pred, None)
879            };
880
881            *latents = match (cfg_plus_schedule.as_ref(), noise_pred_uncond_opt.as_ref()) {
882                (Some(ddim_sched), Some(eps_uncond)) => {
883                    ddim_sched.cfg_plus_step(&*latents, &noise_pred_blended, eps_uncond, t)?
884                }
885                _ => scheduler.step(&noise_pred_blended, t, &*latents)?,
886            };
887
888            // Inpainting: blend preserved regions back at current noise level
889            if let Some(ctx) = inpaint_ctx {
890                let noised_original =
891                    scheduler.add_noise(&ctx.original_latents, ctx.noise.clone(), t)?;
892                *latents = crate::img2img::blend_inpaint_latents(&*latents, ctx, &noised_original)?;
893            }
894
895            self.base.progress.emit(ProgressEvent::DenoiseStep {
896                step: step_idx + 1,
897                total: active_timesteps.len(),
898                elapsed: step_start.elapsed(),
899            });
900        }
901
902        self.base
903            .progress
904            .stage_done(&denoise_label, denoise_start.elapsed());
905        Ok(())
906    }
907
908    /// Prepare img2img latents: VAE encode source image, add noise at the appropriate timestep.
909    /// Returns (noised_latents, start_step, encoded_latents, noise) -- extra fields for inpainting.
910    ///
911    /// `dtype` drives the noise/denoise loop; `vae_dtype` may differ when
912    /// `MOLD_VAE_DTYPE` forces fp32 (encoded latents are cast back to `dtype`
913    /// before returning so the denoise loop runs at engine precision).
914    #[allow(clippy::too_many_arguments)]
915    fn prepare_img2img_latents(
916        &self,
917        vae: &stable_diffusion::vae::AutoEncoderKL,
918        source_bytes: &[u8],
919        width: u32,
920        height: u32,
921        strength: f64,
922        steps: u32,
923        sched: Scheduler,
924        seed: u64,
925        device: &Device,
926        dtype: DType,
927        vae_dtype: DType,
928    ) -> Result<(Tensor, usize, Tensor, Tensor)> {
929        use crate::img_utils::{decode_source_image, NormalizeRange};
930        let cache_key = image_size_cache_key(source_bytes, width, height);
931        let (encoded, cache_hit) = get_or_insert_cached_tensor(
932            &self.source_latent_cache,
933            cache_key,
934            device,
935            dtype,
936            || {
937                self.base
938                    .progress
939                    .stage_start("Encoding source image (VAE)");
940                let encode_start = Instant::now();
941
942                let source_tensor = decode_source_image(
943                    source_bytes,
944                    width,
945                    height,
946                    NormalizeRange::MinusOneToOne,
947                    device,
948                    vae_dtype,
949                )?;
950
951                let encoded = vae.encode(&source_tensor)?;
952                let encoded = (encoded.mode()? * VAE_SCALE)?;
953                // Cast back to engine dtype so the denoise loop stays at its
954                // natural precision when MOLD_VAE_DTYPE upgraded the VAE.
955                let encoded = encoded.to_dtype(dtype)?;
956
957                self.base
958                    .progress
959                    .stage_done("Encoding source image (VAE)", encode_start.elapsed());
960                Ok(encoded)
961            },
962        )?;
963        if cache_hit {
964            self.base.progress.cache_hit("source image latents");
965        }
966
967        let start_step = crate::img2img::img2img_start_index(steps as usize, strength);
968
969        // Build scheduler to get timesteps and add noise
970        let scheduler = crate::scheduler::build_scheduler(
971            sched,
972            steps as usize,
973            PredictionType::Epsilon,
974            false,
975        )?;
976        let timesteps = scheduler.timesteps().to_vec();
977
978        let latent_h = height as usize / 8;
979        let latent_w = width as usize / 8;
980        let noise =
981            crate::engine::seeded_randn(seed, &[1, 4, latent_h, latent_w], device, DType::F32)?;
982        let noise = noise.to_dtype(dtype)?;
983
984        // Add noise at the start timestep
985        let noised = if start_step < timesteps.len() {
986            scheduler.add_noise(&encoded, noise.clone(), timesteps[start_step])?
987        } else {
988            encoded.clone()
989        };
990
991        tracing::info!(
992            start_step,
993            total_steps = steps,
994            strength,
995            "img2img: starting from step {start_step}"
996        );
997
998        Ok((noised, start_step, encoded, noise))
999    }
1000
1001    /// Encode prompt with the single CLIP-L encoder.
1002    #[allow(clippy::too_many_arguments)]
1003    #[allow(clippy::too_many_arguments)]
1004    fn encode_prompt(
1005        &self,
1006        clip: &stable_diffusion::clip::ClipTextTransformer,
1007        tokenizer: &tokenizers::Tokenizer,
1008        prompt: &str,
1009        negative_prompt: &str,
1010        max_len: usize,
1011        device: &Device,
1012        clip_device: &Device,
1013        dtype: DType,
1014        guidance: f64,
1015    ) -> Result<Tensor> {
1016        // SD1.5 caches the **concatenated** `(uncond, cond)` tensor when CFG is
1017        // active, so the cache key must include the negative prompt and the
1018        // guidance scale. Keying on the positive prompt + guidance alone
1019        // returned a stale uncond branch when the user changed only the
1020        // negative prompt — silent wrong output. Mirrors the SD3 / SDXL fix.
1021        let cache_key = cfg_prompt_cache_key(prompt, negative_prompt, guidance);
1022        let (text_embeddings, cache_hit) =
1023            get_or_insert_cached_tensor(&self.prompt_cache, cache_key, device, dtype, || {
1024                let use_cfg = cfg_active(guidance);
1025
1026                self.base.progress.stage_start("Encoding prompt (CLIP-L)");
1027                let encode_start = Instant::now();
1028                let tokens = Self::tokenize(tokenizer, prompt, max_len, clip_device)?;
1029                let text_embeddings = clip.forward(&tokens)?;
1030                self.base
1031                    .progress
1032                    .stage_done("Encoding prompt (CLIP-L)", encode_start.elapsed());
1033
1034                let text_embeddings = if use_cfg {
1035                    let uncond_tokens =
1036                        Self::tokenize(tokenizer, negative_prompt, max_len, clip_device)?;
1037                    let uncond_embeddings = clip.forward(&uncond_tokens)?;
1038                    Tensor::cat(&[&uncond_embeddings, &text_embeddings], 0)?
1039                } else {
1040                    text_embeddings
1041                };
1042
1043                let text_embeddings = text_embeddings.to_device(device)?;
1044                Ok(text_embeddings.to_dtype(dtype)?)
1045            })?;
1046        if cache_hit {
1047            self.base.progress.cache_hit("prompt conditioning");
1048            return Ok(text_embeddings);
1049        }
1050        Ok(text_embeddings)
1051    }
1052
1053    fn cached_mask(
1054        &self,
1055        mask_bytes: &[u8],
1056        latent_h: usize,
1057        latent_w: usize,
1058        device: &Device,
1059        dtype: DType,
1060    ) -> Result<Tensor> {
1061        let key = latent_size_cache_key(mask_bytes, latent_h, latent_w);
1062        let (mask, cache_hit) =
1063            get_or_insert_cached_tensor(&self.mask_cache, key, device, dtype, || {
1064                crate::img_utils::decode_mask_image(mask_bytes, latent_h, latent_w, device, dtype)
1065            })?;
1066        if cache_hit {
1067            self.base.progress.cache_hit("inpaint mask");
1068            return Ok(mask);
1069        }
1070        Ok(mask)
1071    }
1072
1073    /// Load ControlNet model and prepare control tensor from the request.
1074    fn load_controlnet(
1075        &self,
1076        req: &GenerateRequest,
1077        device: &Device,
1078        dtype: DType,
1079    ) -> Result<Option<ControlNetContext>> {
1080        let (control_bytes, control_model_name, scale) =
1081            match (req.control_image.as_ref(), req.control_model.as_ref()) {
1082                (Some(bytes), Some(name)) => (bytes, name, req.control_scale),
1083                _ => return Ok(None),
1084            };
1085
1086        self.base.progress.stage_start("Loading ControlNet");
1087        let cn_start = Instant::now();
1088
1089        // Resolve ControlNet model weights path
1090        let config = mold_core::Config::load_or_default();
1091        let cn_paths =
1092            mold_core::ModelPaths::resolve(control_model_name, &config).ok_or_else(|| {
1093                anyhow::anyhow!(
1094                    "ControlNet model '{}' not found. Pull it with: mold pull {}",
1095                    control_model_name,
1096                    control_model_name,
1097                )
1098            })?;
1099
1100        // Use default SD1.5 UNet config for ControlNet architecture
1101        let unet_config =
1102            candle_transformers::models::stable_diffusion::unet_2d::UNet2DConditionModelConfig::default();
1103        let model = ControlNetModel::load(
1104            &cn_paths.transformer,
1105            device,
1106            dtype,
1107            unet_config,
1108            &self.base.progress,
1109        )?;
1110
1111        self.base
1112            .progress
1113            .stage_done("Loading ControlNet", cn_start.elapsed());
1114
1115        // Decode and preprocess control image
1116        self.base
1117            .progress
1118            .stage_start("Preprocessing control image");
1119        let preprocess_start = Instant::now();
1120
1121        let control_key = image_size_cache_key(control_bytes, req.width, req.height);
1122        let (control_tensor, cache_hit) = get_or_insert_cached_tensor(
1123            &self.control_tensor_cache,
1124            control_key,
1125            device,
1126            dtype,
1127            || {
1128                crate::img_utils::decode_source_image(
1129                    control_bytes,
1130                    req.width,
1131                    req.height,
1132                    crate::img_utils::NormalizeRange::ZeroToOne,
1133                    device,
1134                    dtype,
1135                )
1136            },
1137        )?;
1138        if cache_hit {
1139            self.base.progress.cache_hit("control preprocessing");
1140        }
1141
1142        self.base
1143            .progress
1144            .stage_done("Preprocessing control image", preprocess_start.elapsed());
1145
1146        tracing::info!(
1147            control_model = %control_model_name,
1148            scale,
1149            "ControlNet loaded and control image preprocessed"
1150        );
1151
1152        Ok(Some(ControlNetContext {
1153            model,
1154            control_tensor,
1155            scale,
1156        }))
1157    }
1158
1159    /// Generate an image using sequential loading strategy.
1160    ///
1161    /// Loads components one at a time and drops them when done:
1162    /// 1. Load CLIP-L → encode → drop CLIP-L
1163    /// 2. Load UNet → denoise → drop UNet
1164    /// 3. Load VAE → decode → drop VAE
1165    fn generate_sequential(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1166        let (clip_encoder, clip_tokenizer) = self.validate_paths()?;
1167
1168        // Check memory budget
1169        if let Some(warning) = check_memory_budget(&self.base.paths, LoadStrategy::Sequential) {
1170            self.base.progress.info(&warning);
1171        }
1172
1173        let device = crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)?;
1174        let dtype = if crate::device::is_gpu(&device) {
1175            DType::F16
1176        } else {
1177            DType::F32
1178        };
1179
1180        let sd_config = self.sd_config();
1181        let max_len = sd_config.clip.max_position_embeddings;
1182
1183        let start = Instant::now();
1184        let seed = req.seed.unwrap_or_else(rand_seed);
1185
1186        let width = req.width as usize;
1187        let height = req.height as usize;
1188        let guidance = req.guidance;
1189
1190        tracing::info!(
1191            prompt = %req.prompt,
1192            seed, width, height,
1193            steps = req.steps,
1194            guidance,
1195            "starting sequential SD1.5 generation"
1196        );
1197
1198        self.base
1199            .progress
1200            .info("Using sequential loading (load-use-drop) to minimize peak memory");
1201
1202        // --- Phase 1: Encode prompt (check cache first to skip encoder load) ---
1203        let neg = req.negative_prompt.as_deref().unwrap_or("");
1204        let cache_key = cfg_prompt_cache_key(&req.prompt, neg, guidance);
1205        let text_embeddings = if let Some(tensor) =
1206            restore_cached_tensor(&self.prompt_cache, &cache_key, &device, dtype)?
1207        {
1208            self.base.progress.cache_hit("prompt conditioning");
1209            tensor
1210        } else {
1211            if let Some(status) = memory_status_string() {
1212                self.base.progress.info(&status);
1213            }
1214
1215            let tokenizer = self.load_clip_tokenizer(&clip_tokenizer)?;
1216
1217            let tier1 = self
1218                .pending_placement
1219                .as_ref()
1220                .map(|p| p.text_encoders)
1221                .unwrap_or_default();
1222            let clip_device = crate::device::resolve_device(Some(tier1), || Ok(device.clone()))?;
1223            // Branch on single_file_path so cv:<id>-style Civitai checkpoints
1224            // (A1111 naming) go through `SingleFileBackend` just like eager
1225            // mode does. Without this, candle's diffusers-keyed
1226            // `build_clip_transformer` errors with
1227            // "cannot find tensor text_model.embeddings.token_embedding.weight".
1228            let clip = if let Some(single_file) = self.single_file_path.clone() {
1229                let remap = Self::load_sd15_remap(&single_file)?;
1230                self.base
1231                    .progress
1232                    .stage_start("Loading CLIP-L (single-file)");
1233                let clip_start = Instant::now();
1234                let clip = Self::build_clip_single_file(
1235                    &single_file,
1236                    &remap,
1237                    &sd_config.clip,
1238                    &clip_device,
1239                )?;
1240                self.base
1241                    .progress
1242                    .stage_done("Loading CLIP-L (single-file)", clip_start.elapsed());
1243                clip
1244            } else {
1245                self.base.progress.stage_start("Loading CLIP-L encoder");
1246                let clip_start = Instant::now();
1247                let clip = stable_diffusion::build_clip_transformer(
1248                    &sd_config.clip,
1249                    &clip_encoder,
1250                    &clip_device,
1251                    DType::F32,
1252                )?;
1253                self.base
1254                    .progress
1255                    .stage_done("Loading CLIP-L encoder", clip_start.elapsed());
1256                clip
1257            };
1258
1259            let text_embeddings = self.encode_prompt(
1260                &clip,
1261                &tokenizer,
1262                &req.prompt,
1263                neg,
1264                max_len,
1265                &device,
1266                &clip_device,
1267                dtype,
1268                guidance,
1269            )?;
1270
1271            drop(clip);
1272            self.base.progress.info("Freed CLIP-L encoder");
1273            tracing::info!("CLIP encoder dropped (sequential mode)");
1274
1275            text_embeddings
1276        };
1277
1278        // --- Phase 2: Load UNet and denoise ---
1279        let unet_size = std::fs::metadata(&self.base.paths.transformer)
1280            .map(|m| m.len())
1281            .unwrap_or(0);
1282        // SD1.5 runs CFG by default → batch=2 unless guidance ≈ 1 (LCM).
1283        let unet_batch = if req.guidance > 1.0 { 2 } else { 1 };
1284        let unet_activation_budget = crate::device::activation_bytes(
1285            req.width,
1286            req.height,
1287            unet_batch,
1288            crate::device::dtype_bytes(dtype),
1289            crate::device::ActivationFamily::SdxlUnet,
1290        );
1291        preflight_memory_check("UNet", unet_size, unet_activation_budget)?;
1292        if let Some(status) = memory_status_string() {
1293            self.base.progress.info(&status);
1294        }
1295
1296        self.base.progress.stage_start("Loading UNet (GPU)");
1297        let unet_start = Instant::now();
1298        let unet = self.build_unet_for_strategy(&sd_config, &device, dtype)?;
1299        self.base
1300            .progress
1301            .stage_done("Loading UNet (GPU)", unet_start.elapsed());
1302
1303        let sched = req.scheduler.unwrap_or(self.scheduler);
1304        let is_img2img = req.source_image.is_some();
1305
1306        // For img2img in sequential mode, we need the VAE for encoding before the UNet
1307        let (mut latents, start_step, inpaint_ctx) = if let Some(ref source_bytes) =
1308            req.source_image
1309        {
1310            self.base
1311                .progress
1312                .info("img2img mode: encoding source image before denoising");
1313
1314            // Load VAE first for encoding
1315            self.base.progress.stage_start("Loading VAE (GPU)");
1316            let vae_start = Instant::now();
1317            let vae_dtype = crate::device::resolve_vae_dtype(dtype);
1318            let vae = self.build_vae_for_strategy(&sd_config, &device, vae_dtype)?;
1319            self.base
1320                .progress
1321                .stage_done("Loading VAE (GPU)", vae_start.elapsed());
1322
1323            let (latents, start_step, encoded, noise) = self.prepare_img2img_latents(
1324                &vae,
1325                source_bytes,
1326                req.width,
1327                req.height,
1328                req.strength,
1329                req.steps,
1330                sched,
1331                seed,
1332                &device,
1333                dtype,
1334                vae_dtype,
1335            )?;
1336
1337            // Drop VAE to free memory for UNet
1338            let inpaint_ctx = if let Some(ref mask_bytes) = req.mask_image {
1339                let mask = self.cached_mask(mask_bytes, height / 8, width / 8, &device, dtype)?;
1340                Some(crate::img_utils::InpaintContext {
1341                    original_latents: encoded,
1342                    mask,
1343                    noise,
1344                })
1345            } else {
1346                None
1347            };
1348
1349            drop(vae);
1350            self.base
1351                .progress
1352                .info("Freed VAE (will reload for decode)");
1353            device.synchronize()?;
1354
1355            (latents, start_step, inpaint_ctx)
1356        } else {
1357            let latent_h = height / 8;
1358            let latent_w = width / 8;
1359            let init_scheduler = crate::scheduler::build_scheduler(
1360                sched,
1361                req.steps as usize,
1362                PredictionType::Epsilon,
1363                false,
1364            )?;
1365            let init_noise_sigma = init_scheduler.init_noise_sigma();
1366            drop(init_scheduler);
1367            let latents = (crate::engine::seeded_randn(
1368                seed,
1369                &[1, 4, latent_h, latent_w],
1370                &device,
1371                DType::F32,
1372            )? * init_noise_sigma)?;
1373            (latents.to_dtype(dtype)?, 0, None)
1374        };
1375
1376        // Load ControlNet if requested
1377        let controlnet_ctx = self.load_controlnet(req, &device, dtype)?;
1378
1379        self.denoise_loop(
1380            &unet,
1381            &text_embeddings,
1382            sched,
1383            &mut latents,
1384            guidance,
1385            resolve_cfg_plus(req),
1386            req.steps,
1387            start_step,
1388            inpaint_ctx.as_ref(),
1389            controlnet_ctx.as_ref(),
1390        )?;
1391
1392        // Drop UNet to free memory for VAE decode
1393        drop(controlnet_ctx);
1394        drop(inpaint_ctx);
1395        drop(unet);
1396        drop(text_embeddings);
1397        device.synchronize()?;
1398        self.base.progress.info("Freed UNet");
1399        tracing::info!("UNet dropped (sequential mode)");
1400
1401        // --- Phase 3: Load VAE and decode ---
1402        // (img2img already loaded+dropped VAE above; reload for decode)
1403        let vae_load_label = if is_img2img {
1404            "Reloading VAE (GPU)"
1405        } else {
1406            "Loading VAE (GPU)"
1407        };
1408        self.base.progress.stage_start(vae_load_label);
1409        let vae_start = Instant::now();
1410        let vae_dtype = crate::device::resolve_vae_dtype(dtype);
1411        let vae = self.build_vae_for_strategy(&sd_config, &device, vae_dtype)?;
1412        self.base
1413            .progress
1414            .stage_done(vae_load_label, vae_start.elapsed());
1415
1416        self.base.progress.stage_start("VAE decode");
1417        let vae_decode_start = Instant::now();
1418
1419        let latents = (latents / VAE_SCALE)?;
1420        let latents_for_vae = latents.to_dtype(vae_dtype)?;
1421        let device_for_sync = device.clone();
1422        let img = crate::vae_tiling::decode_with_oom_fallback(
1423            &latents_for_vae,
1424            |t| vae.decode(t).map_err(Into::into),
1425            || {
1426                if let Err(e) = device_for_sync.synchronize() {
1427                    tracing::warn!(
1428                        "SD1.5 (sequential) device.synchronize() after VAE OOM failed: {e}"
1429                    );
1430                }
1431            },
1432        )?;
1433
1434        let img = ((img / 2.)? + 0.5)?.clamp(0f32, 1f32)?;
1435        let img = (img * 255.)?.to_dtype(DType::U8)?;
1436        let img = img.squeeze(0)?;
1437
1438        self.base
1439            .progress
1440            .stage_done("VAE decode", vae_decode_start.elapsed());
1441
1442        let output_metadata = build_output_metadata(req, seed, Some(sched));
1443        let image_bytes = encode_image(
1444            &img,
1445            req.resolved_output_format(),
1446            req.width,
1447            req.height,
1448            output_metadata.as_ref(),
1449        )?;
1450
1451        let generation_time_ms = start.elapsed().as_millis() as u64;
1452        tracing::info!(
1453            generation_time_ms,
1454            seed,
1455            "sequential SD1.5 generation complete"
1456        );
1457
1458        Ok(GenerateResponse {
1459            images: vec![ImageData {
1460                data: image_bytes,
1461                format: req.resolved_output_format(),
1462                width: req.width,
1463                height: req.height,
1464                index: 0,
1465            }],
1466            generation_time_ms,
1467            model: req.model.clone(),
1468            seed_used: seed,
1469            video: None,
1470            gpu: None,
1471        })
1472    }
1473}
1474
1475impl SD15Engine {
1476    fn generate_inner(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1477        // Sequential mode: load-use-drop each component
1478        if self.base.load_strategy == LoadStrategy::Sequential {
1479            return self.generate_sequential(req);
1480        }
1481
1482        // Eager mode: if the requested LoRA stack differs from what's merged
1483        // into the loaded UNet, drop it now so `reload_unet_if_needed` rebuilds
1484        // it with the new wrapper.
1485        let requested_stack = lora_stack_fingerprint(&self.pending_loras);
1486        if requested_stack != self.active_lora_fingerprint {
1487            if let Some(loaded) = self.base.loaded.as_mut() {
1488                if loaded.unet.is_some() {
1489                    loaded.unet = None;
1490                    loaded.device.synchronize()?;
1491                    tracing::info!("SD1.5 UNet dropped (LoRA stack changed)");
1492                }
1493            }
1494            self.active_lora_fingerprint = requested_stack;
1495        }
1496
1497        // Eager mode: reload UNet if dropped after previous VAE decode
1498        self.reload_unet_if_needed()?;
1499
1500        let loaded = self
1501            .base
1502            .loaded
1503            .as_ref()
1504            .ok_or_else(|| anyhow::anyhow!("model not loaded — call load() first"))?;
1505
1506        let start = Instant::now();
1507        let seed = req.seed.unwrap_or_else(rand_seed);
1508
1509        let width = req.width as usize;
1510        let height = req.height as usize;
1511        let guidance = req.guidance;
1512
1513        tracing::info!(
1514            prompt = %req.prompt,
1515            seed, width, height,
1516            steps = req.steps,
1517            guidance,
1518            scheduler = %self.scheduler,
1519            "starting SD1.5 generation"
1520        );
1521
1522        // 1. Encode prompt with CLIP-L
1523        let max_len = loaded.sd_config.clip.max_position_embeddings;
1524        let neg = req.negative_prompt.as_deref().unwrap_or("");
1525        let text_embeddings = self.encode_prompt(
1526            &loaded.clip,
1527            &loaded.tokenizer,
1528            &req.prompt,
1529            neg,
1530            max_len,
1531            &loaded.device,
1532            &loaded.clip_device,
1533            loaded.dtype,
1534            guidance,
1535        )?;
1536
1537        // 2. Build scheduler and create initial latents
1538        let sched = req.scheduler.unwrap_or(self.scheduler);
1539
1540        let (mut latents, start_step, inpaint_ctx) =
1541            if let Some(ref source_bytes) = req.source_image {
1542                let (latents, start_step, encoded, noise) = self.prepare_img2img_latents(
1543                    &loaded.vae,
1544                    source_bytes,
1545                    req.width,
1546                    req.height,
1547                    req.strength,
1548                    req.steps,
1549                    sched,
1550                    seed,
1551                    &loaded.device,
1552                    loaded.dtype,
1553                    loaded.vae_dtype,
1554                )?;
1555                let inpaint_ctx = if let Some(ref mask_bytes) = req.mask_image {
1556                    let mask = self.cached_mask(
1557                        mask_bytes,
1558                        height / 8,
1559                        width / 8,
1560                        &loaded.device,
1561                        loaded.dtype,
1562                    )?;
1563                    Some(crate::img_utils::InpaintContext {
1564                        original_latents: encoded,
1565                        mask,
1566                        noise,
1567                    })
1568                } else {
1569                    None
1570                };
1571                (latents, start_step, inpaint_ctx)
1572            } else {
1573                let latent_h = height / 8;
1574                let latent_w = width / 8;
1575                let init_scheduler = crate::scheduler::build_scheduler(
1576                    sched,
1577                    req.steps as usize,
1578                    PredictionType::Epsilon,
1579                    false,
1580                )?;
1581                let init_noise_sigma = init_scheduler.init_noise_sigma();
1582                drop(init_scheduler);
1583                let latents = (crate::engine::seeded_randn(
1584                    seed,
1585                    &[1, 4, latent_h, latent_w],
1586                    &loaded.device,
1587                    DType::F32,
1588                )? * init_noise_sigma)?;
1589                (latents.to_dtype(loaded.dtype)?, 0, None)
1590            };
1591
1592        // 3. Load ControlNet if requested
1593        let controlnet_ctx = self.load_controlnet(req, &loaded.device, loaded.dtype)?;
1594
1595        // 4. Denoising loop
1596        let unet = loaded
1597            .unet
1598            .as_ref()
1599            .ok_or_else(|| anyhow::anyhow!("UNet not loaded"))?;
1600        self.denoise_loop(
1601            unet,
1602            &text_embeddings,
1603            sched,
1604            &mut latents,
1605            guidance,
1606            resolve_cfg_plus(req),
1607            req.steps,
1608            start_step,
1609            inpaint_ctx.as_ref(),
1610            controlnet_ctx.as_ref(),
1611        )?;
1612
1613        // Drop UNet before VAE decode to free VRAM for conv2d intermediates.
1614        drop(controlnet_ctx);
1615        drop(inpaint_ctx);
1616        // End the immutable borrow of `loaded` so we can take a mutable one.
1617        let _ = loaded;
1618        let loaded = self.base.loaded.as_mut().unwrap();
1619        loaded.unet = None;
1620        loaded.device.synchronize()?;
1621        tracing::info!("UNet dropped to free VRAM for VAE decode");
1622        // Re-borrow as immutable for VAE decode + output encoding.
1623        let _ = loaded;
1624        let loaded = self.base.loaded.as_ref().unwrap();
1625
1626        // 5. VAE decode
1627        self.base.progress.stage_start("VAE decode");
1628        let vae_start = Instant::now();
1629
1630        let latents = (latents / VAE_SCALE)?;
1631        let latents_for_vae = latents.to_dtype(loaded.vae_dtype)?;
1632        let vae = &loaded.vae;
1633        let device_for_sync = loaded.device.clone();
1634        let img = crate::vae_tiling::decode_with_oom_fallback(
1635            &latents_for_vae,
1636            |t| vae.decode(t).map_err(Into::into),
1637            || {
1638                if let Err(e) = device_for_sync.synchronize() {
1639                    tracing::warn!(
1640                        "SD1.5 (parallel) device.synchronize() after VAE OOM failed: {e}"
1641                    );
1642                }
1643            },
1644        )?;
1645
1646        let img = ((img / 2.)? + 0.5)?.clamp(0f32, 1f32)?;
1647        let img = (img * 255.)?.to_dtype(DType::U8)?;
1648        let img = img.squeeze(0)?;
1649
1650        self.base
1651            .progress
1652            .stage_done("VAE decode", vae_start.elapsed());
1653
1654        // 5. Encode to image format
1655        let output_metadata = build_output_metadata(req, seed, Some(sched));
1656        let image_bytes = encode_image(
1657            &img,
1658            req.resolved_output_format(),
1659            req.width,
1660            req.height,
1661            output_metadata.as_ref(),
1662        )?;
1663
1664        let generation_time_ms = start.elapsed().as_millis() as u64;
1665        tracing::info!(generation_time_ms, seed, "SD1.5 generation complete");
1666
1667        Ok(GenerateResponse {
1668            images: vec![ImageData {
1669                data: image_bytes,
1670                format: req.resolved_output_format(),
1671                width: req.width,
1672                height: req.height,
1673                index: 0,
1674            }],
1675            generation_time_ms,
1676            model: req.model.clone(),
1677            seed_used: seed,
1678            video: None,
1679            gpu: None,
1680        })
1681    }
1682}
1683
1684impl InferenceEngine for SD15Engine {
1685    fn generate(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1686        self.pending_placement = req.placement.clone();
1687        self.pending_loras = super::lora::effective_sd15_loras(req);
1688        let result = self.generate_inner(req);
1689        self.pending_placement = None;
1690        self.pending_loras.clear();
1691        result
1692    }
1693
1694    fn model_name(&self) -> &str {
1695        self.base.model_name()
1696    }
1697
1698    fn is_loaded(&self) -> bool {
1699        self.base.is_loaded()
1700    }
1701
1702    fn load(&mut self) -> Result<()> {
1703        SD15Engine::load(self)
1704    }
1705
1706    fn unload(&mut self) {
1707        self.base.unload();
1708        clear_cache(&self.prompt_cache);
1709        clear_cache(&self.source_latent_cache);
1710        clear_cache(&self.mask_cache);
1711        clear_cache(&self.control_tensor_cache);
1712        self.active_lora_fingerprint.clear();
1713    }
1714
1715    fn set_on_progress(&mut self, callback: ProgressCallback) {
1716        self.base.set_on_progress(callback);
1717    }
1718
1719    fn clear_on_progress(&mut self) {
1720        self.base.clear_on_progress();
1721    }
1722
1723    fn model_paths(&self) -> Option<&mold_core::ModelPaths> {
1724        Some(&self.base.paths)
1725    }
1726}
1727
#[cfg(test)]
// Unit tests for the SD1.5 engine: single-file construction/dispatch,
// shared-pool reuse, CFG gating, prompt-cache keying and LoRA fingerprints.
mod tests {
    use super::*;
    use crate::engine::InferenceEngine;
    use crate::shared_pool::SharedPool;
    use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
    use std::collections::HashMap;
    use std::path::PathBuf;
    use std::sync::{Arc, Mutex};
    use tokenizers::models::bpe::BPE;

    /// Overrides an environment variable while the guard is alive.
    ///
    /// On drop, the variable is restored to its previous value, or removed if
    /// it was previously unset. This also happens when the test panics, so
    /// an override cannot leak into tests that run later in the same
    /// process.
    ///
    /// NOTE: the process environment is shared by the parallel test harness.
    /// This guard prevents leaking state past the test, but it does not
    /// serialise concurrent readers of the same variable.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<std::ffi::OsString>,
    }

    impl EnvVarGuard {
        /// Records the current value of `key`, then sets it to `value`.
        fn set(key: &'static str, value: &str) -> Self {
            let previous = std::env::var_os(key);
            std::env::set_var(key, value);
            Self { key, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => std::env::set_var(self.key, value),
                None => std::env::remove_var(self.key),
            }
        }
    }

    /// Synthesise a minimal SD1.5-shaped single-file safetensors with one
    /// representative key per component bucket. Each tensor holds a single
    /// zero F32 value. `loader::single_file::load` only reads the header,
    /// and the constructor does not materialise weights.
    fn synth_sd15_single_file(name: &str) -> PathBuf {
        let path = std::env::temp_dir().join(format!(
            "mold-sd15-from-sf-{}-{}-{}.safetensors",
            name,
            std::process::id(),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos(),
        ));

        let keys: &[&str] = &[
            // UNet
            "model.diffusion_model.input_blocks.0.0.weight",
            "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight",
            // VAE
            "first_stage_model.encoder.down.0.block.0.norm1.weight",
            "first_stage_model.quant_conv.weight",
            // CLIP-L
            "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight",
            "cond_stage_model.transformer.text_model.final_layer_norm.weight",
        ];

        let f32_zero = 0.0f32.to_le_bytes().to_vec();
        let buffers: Vec<Vec<u8>> = keys.iter().map(|_| f32_zero.clone()).collect();
        let mut tensors: HashMap<String, TensorView<'_>> = HashMap::new();
        for (key, buf) in keys.iter().zip(buffers.iter()) {
            tensors.insert(
                (*key).to_string(),
                TensorView::new(SafeDtype::F32, vec![1], buf).unwrap(),
            );
        }
        serialize_to_file(&tensors, &None, &path).unwrap();
        path
    }

    #[test]
    fn from_single_file_constructs_for_synthetic_sd15_checkpoint() {
        let single_file = synth_sd15_single_file("ok");
        // The companion tokenizer is pulled separately (phase 2.7). The
        // constructor does not validate it; it only stores the path.
        let tokenizer_path = std::env::temp_dir().join("mold-sd15-tok-stub.json");

        let engine = SD15Engine::from_single_file(
            "dreamshaper-8".to_string(),
            single_file.clone(),
            tokenizer_path,
            Scheduler::default(),
            LoadStrategy::Eager,
            0,
            None,
        )
        .expect("constructor must accept a valid SD1.5 single-file layout");

        assert_eq!(engine.model_name(), "dreamshaper-8");
        assert_eq!(
            engine.single_file_path.as_deref(),
            Some(single_file.as_path()),
            "single-file path must be stashed for the future load() branch",
        );
        assert!(
            !engine.is_loaded(),
            "constructor must not eagerly materialise model weights",
        );

        let _ = std::fs::remove_file(single_file);
    }

    #[test]
    fn sd15_loads_clip_tokenizer_through_shared_pool() {
        let dir = tempfile::tempdir().unwrap();
        let tokenizer_path = dir.path().join("clip-tokenizer.json");
        tokenizers::Tokenizer::new(BPE::default())
            .save(&tokenizer_path, false)
            .unwrap();
        let weights_path = dir.path().join("weights.safetensors");
        std::fs::write(&weights_path, b"stub").unwrap();

        let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
        let pooled = shared_pool
            .lock()
            .unwrap()
            .load_tokenizer(&tokenizer_path)
            .unwrap();

        let paths = ModelPaths {
            transformer: weights_path.clone(),
            transformer_shards: Vec::new(),
            vae: weights_path.clone(),
            spatial_upscaler: None,
            temporal_upscaler: None,
            distilled_lora: None,
            t5_encoder: None,
            clip_encoder: Some(weights_path),
            t5_tokenizer: None,
            clip_tokenizer: Some(tokenizer_path.clone()),
            clip_encoder_2: None,
            clip_tokenizer_2: None,
            text_encoder_files: Vec::new(),
            text_tokenizer: None,
            decoder: None,
        };
        let engine = SD15Engine::new(
            "sd15-test".to_string(),
            paths,
            Scheduler::default(),
            LoadStrategy::Eager,
            0,
            Some(shared_pool),
        );

        let loaded = engine.load_clip_tokenizer(&tokenizer_path).unwrap();

        assert!(Arc::ptr_eq(&pooled, &loaded));
    }

    #[test]
    fn sd15_loads_vae_tensors_through_shared_pool() {
        let dir = tempfile::tempdir().unwrap();
        let vae_path = dir.path().join("vae.safetensors");
        let weight = 1.0f32.to_le_bytes();
        let mut tensors = HashMap::new();
        tensors.insert(
            "encoder.conv_in.weight".to_string(),
            TensorView::new(SafeDtype::F32, vec![1], &weight).unwrap(),
        );
        serialize_to_file(&tensors, &None, &vae_path).unwrap();

        let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
        let pooled = shared_pool
            .lock()
            .unwrap()
            .load_safetensors_cpu_tensors(std::slice::from_ref(&vae_path))
            .unwrap()
            .unwrap();

        let paths = ModelPaths {
            transformer: dir.path().join("unet.safetensors"),
            transformer_shards: Vec::new(),
            vae: vae_path.clone(),
            spatial_upscaler: None,
            temporal_upscaler: None,
            distilled_lora: None,
            t5_encoder: None,
            clip_encoder: Some(dir.path().join("clip.safetensors")),
            t5_tokenizer: None,
            clip_tokenizer: Some(dir.path().join("clip-tokenizer.json")),
            clip_encoder_2: None,
            clip_tokenizer_2: None,
            text_encoder_files: Vec::new(),
            text_tokenizer: None,
            decoder: None,
        };
        let engine = SD15Engine::new(
            "sd15-test".to_string(),
            paths,
            Scheduler::default(),
            LoadStrategy::Eager,
            0,
            Some(shared_pool),
        );

        let loaded = engine.load_vae_cpu_tensors().unwrap().unwrap();

        assert!(Arc::ptr_eq(&pooled, &loaded));
    }

    #[test]
    fn load_branches_to_single_file_path_and_invokes_candle_constructors() {
        // Phase 2.8.5 smoke test. `load()` must:
        // - dispatch to the single-file branch;
        // - hand a `VarBuilder::from_backend(SingleFileBackend)` to candle's
        //   `UNet2DConditionModel::new`;
        // - surface an error from the SingleFileBackend's tensor lookup,
        //   because the synthetic checkpoint lacks the full SD1.5 tensor set.
        // This confirms the wiring without paying the cost of constructing
        // real-shape models. The full real-shape `load().unwrap()` smoke runs
        // in 2.10's <gpu-host> UAT against a real DreamShaper checkpoint.
        let single_file = synth_sd15_single_file("load-branch");
        let tokenizer_stub = std::env::temp_dir().join(format!(
            "mold-sd15-tok-stub-{}-{}.json",
            std::process::id(),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos(),
        ));
        std::fs::write(&tokenizer_stub, b"").unwrap();

        let mut engine = SD15Engine::from_single_file(
            "dreamshaper-8".to_string(),
            single_file.clone(),
            tokenizer_stub.clone(),
            Scheduler::Ddim,
            LoadStrategy::Eager,
            0,
            None,
        )
        .expect("constructor");

        // Force CPU only for the duration of `load()`. The guard restores any
        // pre-existing `MOLD_DEVICE` (rather than unconditionally removing
        // it), and it does so even if `expect_err` panics.
        let err = {
            let _cpu_device = EnvVarGuard::set("MOLD_DEVICE", "cpu");
            SD15Engine::load(&mut engine).expect_err(
                "synthetic single-file checkpoint can't satisfy SD1.5's full tensor set; \
                 dispatch should land in the candle constructor and fail there",
            )
        };

        let msg = err.to_string();
        // Proof that we got past the dispatch: the error originates from the
        // single-file remap / SingleFileBackend layer. It does not come from
        // the diffusers loader, which would mention
        // "transformer/diffusion_pytorch_model.safetensors". It also does not
        // come from the dispatch itself, which has no error surface of its
        // own and just calls `load_components_single_file`.
        assert!(
            msg.contains("single-file") || msg.contains("rename rule"),
            "expected error from the single-file load layer, got: {msg}",
        );

        let _ = std::fs::remove_file(single_file);
        let _ = std::fs::remove_file(tokenizer_stub);
    }

    /// Sequential single-file dispatch must hand CLIP-L construction to
    /// `SingleFileBackend`, NOT to candle's diffusers-keyed
    /// `build_clip_transformer`. `SingleFileBackend` translates diffusers
    /// `text_model.X` keys through `Sd15Remap` into A1111 reads. This is the
    /// same contract as the SDXL sibling test.
    ///
    /// Locks the SD1.5 half of bug 1 from
    /// `tasks/catalog-run-bridge-option-c-handoff.md`.
    #[test]
    fn build_clip_single_file_dispatches_through_backend_not_diffusers_loader() {
        let single_file = synth_sd15_single_file("seq-clip-dispatch");
        let remap = SD15Engine::load_sd15_remap(&single_file).expect("remap");

        let result = SD15Engine::build_clip_single_file(
            &single_file,
            &remap,
            &stable_diffusion::clip::Config::v1_5(),
            &Device::Cpu,
        );

        let err = result.expect_err("synthetic CLIP-L is incomplete");
        let msg = err.to_string();
        assert!(
            !msg.contains("cannot find tensor text_model"),
            "expected failure from SingleFileBackend ('no rename rule for diffusers key …'), \
             not from candle's diffusers `from_mmaped_safetensors`. Sequential dispatch \
             is still routing through `build_clip_transformer`. Got: {msg}",
        );

        let _ = std::fs::remove_file(single_file);
    }

    /// UNet sequential dispatch parity. SD1.5 sibling of the SDXL UNet
    /// dispatch test.
    #[test]
    fn build_unet_single_file_dispatches_through_backend_not_diffusers_loader() {
        let single_file = synth_sd15_single_file("seq-unet-dispatch");
        let remap = SD15Engine::load_sd15_remap(&single_file).expect("remap");

        let result = SD15Engine::build_unet_single_file(
            &single_file,
            &remap,
            &stable_diffusion::StableDiffusionConfig::v1_5(None, None, None),
            &Device::Cpu,
            DType::F32,
        );

        let err = result.expect_err("synthetic UNet is incomplete");
        let msg = err.to_string();
        assert!(
            !msg.contains("cannot find tensor conv_in"),
            "expected failure from SingleFileBackend, not diffusers loader. Got: {msg}",
        );

        let _ = std::fs::remove_file(single_file);
    }

    /// Real-shape SD1.5 `load()` smoke test: full UNet + VAE + CLIP-L
    /// construction against a synthetic checkpoint.
    ///
    /// Demoted to `#[ignore]` per the 2.6 handoff, for two reasons:
    /// - ~3.4 GB of f32 zeros at SD1.5-correct shapes exceeds the test budget;
    /// - the construction path cannot run today because of the candle-mold
    ///   private-field blocker (see `load_components_single_file`).
    ///
    /// Kept as documentation of the eventual UAT contract. 2.10's <gpu-host>
    /// UAT runs the real-checkpoint variant.
    #[test]
    #[ignore]
    fn from_single_file_real_shape_load_smoke() {
        // Implementation deferred to the candle-fork-accessor follow-up.
        // The 2.10 UAT replaces this with a pull of a real DreamShaper 8 /
        // Realistic Vision V5 checkpoint, a load, and a 4-step generation.
    }

    #[test]
    fn from_single_file_rejects_missing_file() {
        let bogus = std::env::temp_dir().join(format!(
            "mold-sd15-from-sf-missing-{}-{}.safetensors",
            std::process::id(),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos(),
        ));

        let result = SD15Engine::from_single_file(
            "missing".to_string(),
            bogus,
            std::env::temp_dir().join("mold-sd15-tok-stub.json"),
            Scheduler::default(),
            LoadStrategy::Eager,
            0,
            None,
        );

        assert!(
            result.is_err(),
            "constructor must surface a missing-file error before deeper parsing",
        );
    }

    // The SD1.5 denoise loop and the prompt encoder both gate the
    // unconditional pass on `cfg_active(guidance)`. These tests pin the
    // predicate those branches use, so a regression to `guidance > 1.0`
    // (which would break LCM / Lightning at exactly cfg=1.0) is caught here.

    #[test]
    fn test_cfg_disabled_at_guidance_1_0() {
        assert!(!cfg_active(1.0));
    }

    #[test]
    fn test_cfg_disabled_just_below_1_0() {
        assert!(!cfg_active(1.0 - 1e-5));
    }

    #[test]
    fn test_cfg_enabled_at_guidance_1_5() {
        assert!(cfg_active(1.5));
    }

    #[test]
    fn test_cfg_enabled_at_guidance_7_5() {
        assert!(cfg_active(7.5));
    }

    /// Regression test for the SD1.5 prompt-cache key bug. The original code
    /// keyed only on the positive prompt + guidance, so it returned a stale
    /// `(uncond_old, cond)` pair when the user changed just the negative
    /// prompt. Mirrors the SD3 / SDXL regression tests in their respective
    /// pipelines.
    #[test]
    fn sd15_prompt_cache_distinguishes_negative_prompt_changes() {
        use crate::cache::{cfg_prompt_cache_key, store_cached_tensor};

        let cache: Mutex<LruCache<CfgPromptCacheKey, CachedTensor>> =
            Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY));
        let device = Device::Cpu;
        let dtype = DType::F32;
        let embeddings = candle_core::Tensor::zeros((1, 4), dtype, &device).unwrap();

        let key_a = cfg_prompt_cache_key("a cat", "blurry", 7.0);
        store_cached_tensor(&cache, key_a.clone(), &embeddings).unwrap();

        // Same positive + same guidance, different negative → MUST miss.
        let key_b = cfg_prompt_cache_key("a cat", "low quality", 7.0);
        let restored = restore_cached_tensor(&cache, &key_b, &device, dtype).unwrap();
        assert!(
            restored.is_none(),
            "different negative prompt must miss the cache (silent-wrong-output bug)",
        );

        // Same key as the insert → MUST hit.
        let restored = restore_cached_tensor(&cache, &key_a, &device, dtype).unwrap();
        assert!(
            restored.is_some(),
            "identical (pos, neg, guidance) must hit",
        );
    }

    /// `lora_stack_fingerprint` must produce equal fingerprints for the same
    /// `(path, scale)` pairs, and distinct ones when either field changes.
    /// This is the predicate that drives the UNet drop on a LoRA stack change
    /// in eager mode. Same shape as the SDXL test of the same name.
    #[test]
    fn lora_stack_fingerprint_equality_drives_unet_drop() {
        let a = mold_core::LoraWeight {
            path: "/a.safetensors".to_string(),
            scale: 0.8,
        };
        let b = mold_core::LoraWeight {
            path: "/b.safetensors".to_string(),
            scale: 0.4,
        };
        let same_a = mold_core::LoraWeight {
            path: "/a.safetensors".to_string(),
            scale: 0.8,
        };

        // Identical contents → identical fingerprint.
        assert_eq!(
            lora_stack_fingerprint(&[a.clone(), b.clone()]),
            lora_stack_fingerprint(&[same_a.clone(), b.clone()])
        );

        // Same path but different scale → different fingerprint.
        let scaled = mold_core::LoraWeight {
            path: "/a.safetensors".to_string(),
            scale: 0.81,
        };
        assert_ne!(
            lora_stack_fingerprint(std::slice::from_ref(&a)),
            lora_stack_fingerprint(std::slice::from_ref(&scaled))
        );

        // Reordering yields a different fingerprint (stack order matters).
        assert_ne!(
            lora_stack_fingerprint(&[a.clone(), b.clone()]),
            lora_stack_fingerprint(&[b, a])
        );
    }
}