Skip to main content

mold_inference/sdxl/
pipeline.rs

1use anyhow::{bail, Result};
2use candle_core::{DType, Device, Module, Tensor, D};
3use candle_transformers::models::stable_diffusion;
4use candle_transformers::models::stable_diffusion::schedulers::PredictionType;
5use mold_core::{GenerateRequest, GenerateResponse, ImageData, ModelPaths, Scheduler};
6use std::collections::HashMap;
7use std::path::{Path, PathBuf};
8use std::sync::{Arc, Mutex};
9use std::time::Instant;
10
11use crate::cache::{
12    cfg_prompt_cache_key, clear_cache, get_or_insert_cached_tensor, image_size_cache_key,
13    latent_size_cache_key, restore_cached_tensor, CachedTensor, CfgPromptCacheKey,
14    ImageSizeCacheKey, LatentSizeCacheKey, LruCache, DEFAULT_IMAGE_CACHE_CAPACITY,
15    DEFAULT_PROMPT_CACHE_CAPACITY,
16};
17use crate::cfg_plus_ddim::DdimAlphaSchedule;
18use crate::device::{check_memory_budget, memory_status_string, preflight_memory_check};
19use crate::engine::{cfg_active, rand_seed, resolve_cfg_plus, InferenceEngine, LoadStrategy};
20use crate::engine_base::EngineBase;
21use crate::image::{build_output_metadata, encode_image};
22use crate::progress::{ProgressCallback, ProgressEvent};
23
24/// Loaded SDXL model components, ready for inference.
25struct LoadedSDXL {
26    /// None after being dropped for VAE decode VRAM; reloaded on next generate.
27    unet: Option<stable_diffusion::unet_2d::UNet2DConditionModel>,
28    vae: stable_diffusion::vae::AutoEncoderKL,
29    clip_l: stable_diffusion::clip::ClipTextTransformer,
30    clip_g: stable_diffusion::clip::ClipTextTransformer,
31    tokenizer_l: Arc<tokenizers::Tokenizer>,
32    tokenizer_g: Arc<tokenizers::Tokenizer>,
33    sd_config: stable_diffusion::StableDiffusionConfig,
34    device: Device,
35    /// Device the CLIP-L / CLIP-G weights live on (shared — Tier 1 groups them).
36    clip_device: Device,
37    dtype: DType,
38    /// Effective VAE dtype after `MOLD_VAE_DTYPE` resolution. Captured at
39    /// load time so img2img encode + decode use the matching precision when
40    /// fp32 is forced. May equal `dtype` (default behaviour preserved).
41    vae_dtype: DType,
42}
43
44/// SDXL inference engine backed by candle's stable_diffusion module.
45pub struct SDXLEngine {
46    base: EngineBase<LoadedSDXL>,
47    scheduler: Scheduler,
48    is_turbo: bool,
49    shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
50    prompt_cache: Mutex<LruCache<CfgPromptCacheKey, CachedTensor>>,
51    source_latent_cache: Mutex<LruCache<ImageSizeCacheKey, CachedTensor>>,
52    mask_cache: Mutex<LruCache<LatentSizeCacheKey, CachedTensor>>,
53    pending_placement: Option<mold_core::types::DevicePlacement>,
54    /// `Some(path)` when the engine was built from a Civitai single-file
55    /// checkpoint via `from_single_file()`. `load()` branches on this to
56    /// wire a custom `SingleFileBackend` (translating diffusers `vb.get`
57    /// calls through `SdxlRemap`, including row-wise slicing for the
58    /// CLIP-G fused QKV slabs) instead of calling candle's per-component
59    /// `build_*` helpers. `None` for diffusers-layout (HF) checkpoints.
60    pub(crate) single_file_path: Option<PathBuf>,
61    /// Per-request LoRA list, set from `req.lora` / `req.loras` in
62    /// `generate()` and cleared after. Read by `build_unet_for_strategy` to
63    /// decide whether to wrap the UNet's `VarBuilder` with `super::lora`'s
64    /// `wrap_backend_with_lora`. Empty in the no-LoRA hot path.
65    pending_loras: Vec<mold_core::LoraWeight>,
66    /// Fingerprint of the LoRA stack currently merged into the loaded UNet
67    /// (eager mode). When `effective_sdxl_loras(req)` produces a different
68    /// fingerprint, `generate_inner` drops the UNet so the next
69    /// `reload_unet_if_needed` call rebuilds it with the new LoRA wrapper.
70    /// Empty when no LoRA is active. This is the SDXL equivalent of
71    /// `FluxEngine::active_lora`.
72    active_lora_fingerprint: Vec<(String, u64)>,
73}
74
75/// Compute a stable fingerprint for a LoRA stack: ordered list of
76/// `(path, scale_bits)`. Two stacks are equal iff they merge to the same
77/// transformer, so a comparison drives UNet reload on any change (swap,
78/// scale, add, remove, reorder).
79fn lora_stack_fingerprint(loras: &[mold_core::LoraWeight]) -> Vec<(String, u64)> {
80    loras
81        .iter()
82        .map(|w| (w.path.clone(), w.scale.to_bits()))
83        .collect()
84}
85
/// VAE latent scaling factor used for non-turbo SDXL models.
///
/// NOTE(review): 0.18215 is the SD 1.x/2.x factor; the diffusers SDXL VAE
/// config lists `scaling_factor = 0.13025` for base SDXL too. candle's
/// stable-diffusion example also pairs 0.18215 with `Xl`, so this may be
/// intentional — confirm against the decode path before changing it.
const VAE_SCALE_STANDARD: f64 = 0.18215;
/// VAE latent scaling factor used for SDXL Turbo models.
const VAE_SCALE_TURBO: f64 = 0.13025;
90
91fn resolve_sdxl_vae_dtype(default_dtype: DType, single_file: bool) -> DType {
92    let default = if single_file {
93        // Several SDXL Civitai single-file finetunes decode to all-black
94        // images when their baked VAE runs in fp16. Use fp32 by default for
95        // this path while still honoring an explicit MOLD_VAE_DTYPE override.
96        DType::F32
97    } else {
98        default_dtype
99    };
100    crate::device::resolve_vae_dtype(default)
101}
102
103impl SDXLEngine {
104    pub fn new(
105        model_name: String,
106        paths: ModelPaths,
107        scheduler: Scheduler,
108        is_turbo: bool,
109        load_strategy: LoadStrategy,
110        gpu_ordinal: usize,
111        shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
112    ) -> Self {
113        Self {
114            base: EngineBase::new(model_name, paths, load_strategy, gpu_ordinal),
115            scheduler,
116            is_turbo,
117            shared_pool,
118            prompt_cache: Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY)),
119            source_latent_cache: Mutex::new(LruCache::new(DEFAULT_IMAGE_CACHE_CAPACITY)),
120            mask_cache: Mutex::new(LruCache::new(DEFAULT_IMAGE_CACHE_CAPACITY)),
121            pending_placement: None,
122            single_file_path: None,
123            pending_loras: Vec::new(),
124            active_lora_fingerprint: Vec::new(),
125        }
126    }
127
128    /// Construct an SDXL engine from a Civitai single-file `.safetensors`
129    /// checkpoint.
130    ///
131    /// Header-parses the file via `loader::single_file::load(_, Family::Sdxl)`,
132    /// validates that the SDXL A1111 → diffusers rename rules cover its
133    /// UNet / VAE / CLIP-L / CLIP-G tensor keys via
134    /// `loader::sdxl_keys::build_sdxl_remap(_)`, and stashes the path in
135    /// `single_file_path`. The actual UNet / VAE / CLIP-L / CLIP-G
136    /// materialisation (with a custom `SimpleBackend` that translates each
137    /// diffusers `vb.get(name)` into the corresponding A1111 key — including
138    /// row-wise slicing for the CLIP-G fused QKV slabs — inside the mmap'd
139    /// single file) lands in a downstream phase. This constructor does
140    /// **not** build the model.
141    ///
142    /// `clip_l_tokenizer` and `clip_g_tokenizer` are paths to companion-pulled
143    /// tokenizer assets (phase 2.7); tokenizers never live inside the
144    /// single-file checkpoint.
145    ///
146    /// `is_turbo` is threaded through from the manifest / model config —
147    /// Civitai SDXL Turbo checkpoints are not structurally distinguishable
148    /// from standard SDXL at the safetensors-header level, so the caller
149    /// (the 2.6 factory) makes the call based on `model_cfg.is_turbo` (or
150    /// the `model_name.contains("turbo")` fallback). Drives the
151    /// `VAE_SCALE_TURBO` / `EulerAncestral` defaults at load time via
152    /// `sdxl_config()`.
153    #[allow(clippy::too_many_arguments)]
154    pub fn from_single_file(
155        model_name: String,
156        single_file_path: PathBuf,
157        clip_l_tokenizer: PathBuf,
158        clip_g_tokenizer: PathBuf,
159        scheduler: Scheduler,
160        is_turbo: bool,
161        load_strategy: LoadStrategy,
162        gpu_ordinal: usize,
163        shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
164    ) -> Result<Self> {
165        if !single_file_path.exists() {
166            bail!(
167                "single-file checkpoint not found: {}",
168                single_file_path.display()
169            );
170        }
171
172        // Header-parse and validate the SDXL layout. `single_file::load`
173        // is header-only — no tensor data is mmap'd here.
174        let bundle = crate::loader::single_file::load(
175            &single_file_path,
176            mold_catalog::families::Family::Sdxl,
177        )?;
178
179        // Validate that every diffusers key the future SimpleBackend will
180        // request resolves to a real A1111 source — including the CLIP-G
181        // fused QKV slabs (each `attn.in_proj_*` key expands into three
182        // diffusers entries via `RenameOutput::FusedSlice`). Catches
183        // malformed checkpoints at construction time rather than deep in
184        // the eventual `load()` call.
185        let _remap = crate::loader::sdxl_keys::build_sdxl_remap(&bundle)?;
186
187        // `ModelPaths` is a diffusers-layout view; for single-file the
188        // transformer / vae / clip_encoder / clip_encoder_2 all materialise
189        // from the same checkpoint at load time. The real branch lives on
190        // the new `single_file_path` field — future `load()` consults it
191        // before falling through to the diffusers `build_*` helpers.
192        let paths = ModelPaths {
193            transformer: single_file_path.clone(),
194            transformer_shards: Vec::new(),
195            vae: single_file_path.clone(),
196            spatial_upscaler: None,
197            temporal_upscaler: None,
198            distilled_lora: None,
199            t5_encoder: None,
200            clip_encoder: Some(single_file_path.clone()),
201            t5_tokenizer: None,
202            clip_tokenizer: Some(clip_l_tokenizer),
203            clip_encoder_2: Some(single_file_path.clone()),
204            clip_tokenizer_2: Some(clip_g_tokenizer),
205            text_encoder_files: Vec::new(),
206            text_tokenizer: None,
207            decoder: None,
208        };
209
210        // `is_turbo` is threaded through from the 2.6 factory based on the
211        // manifest's `is_turbo` flag (or the `model_name.contains("turbo")`
212        // fallback). Civitai SDXL Turbo checkpoints aren't structurally
213        // distinguishable from standard SDXL at the safetensors-header
214        // level — turbo-vs-standard is purely a load-time concern (VAE
215        // scale, scheduler defaults), so the constructor accepts it
216        // verbatim and stashes it for `sdxl_config()` to read.
217        Ok(Self {
218            base: EngineBase::new(model_name, paths, load_strategy, gpu_ordinal),
219            scheduler,
220            is_turbo,
221            shared_pool,
222            prompt_cache: Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY)),
223            source_latent_cache: Mutex::new(LruCache::new(DEFAULT_IMAGE_CACHE_CAPACITY)),
224            mask_cache: Mutex::new(LruCache::new(DEFAULT_IMAGE_CACHE_CAPACITY)),
225            pending_placement: None,
226            single_file_path: Some(single_file_path),
227            pending_loras: Vec::new(),
228            active_lora_fingerprint: Vec::new(),
229        })
230    }
231
232    /// Validate and return required SDXL paths.
233    fn validate_paths(
234        &self,
235    ) -> Result<(
236        std::path::PathBuf,
237        std::path::PathBuf,
238        std::path::PathBuf,
239        std::path::PathBuf,
240    )> {
241        let clip_encoder = self
242            .base
243            .paths
244            .clip_encoder
245            .as_ref()
246            .ok_or_else(|| anyhow::anyhow!("CLIP-L encoder path required for SDXL models"))?
247            .clone();
248        let clip_tokenizer = self
249            .base
250            .paths
251            .clip_tokenizer
252            .as_ref()
253            .ok_or_else(|| anyhow::anyhow!("CLIP-L tokenizer path required for SDXL models"))?
254            .clone();
255        let clip_encoder_2 = self
256            .base
257            .paths
258            .clip_encoder_2
259            .as_ref()
260            .ok_or_else(|| anyhow::anyhow!("CLIP-G encoder path required for SDXL models"))?
261            .clone();
262        let clip_tokenizer_2 = self
263            .base
264            .paths
265            .clip_tokenizer_2
266            .as_ref()
267            .ok_or_else(|| anyhow::anyhow!("CLIP-G tokenizer path required for SDXL models"))?
268            .clone();
269
270        for (label, path) in [
271            ("transformer (UNet)", &self.base.paths.transformer),
272            ("vae", &self.base.paths.vae),
273            ("clip_encoder (CLIP-L)", &clip_encoder),
274            ("clip_tokenizer (CLIP-L)", &clip_tokenizer),
275            ("clip_encoder_2 (CLIP-G)", &clip_encoder_2),
276            ("clip_tokenizer_2 (CLIP-G)", &clip_tokenizer_2),
277        ] {
278            if !path.exists() {
279                bail!("{label} file not found: {}", path.display());
280            }
281        }
282
283        Ok((
284            clip_encoder,
285            clip_tokenizer,
286            clip_encoder_2,
287            clip_tokenizer_2,
288        ))
289    }
290
291    fn load_clip_tokenizer(
292        &self,
293        clip_tokenizer: &std::path::Path,
294        label: &str,
295    ) -> Result<Arc<tokenizers::Tokenizer>> {
296        if let Some(ref pool) = self.shared_pool {
297            return pool.lock().unwrap().load_tokenizer(clip_tokenizer);
298        }
299        Ok(Arc::new(
300            tokenizers::Tokenizer::from_file(clip_tokenizer)
301                .map_err(|e| anyhow::anyhow!("failed to load {label} tokenizer: {e}"))?,
302        ))
303    }
304
305    /// Create the SDXL config.
306    fn sd_config(&self) -> stable_diffusion::StableDiffusionConfig {
307        if self.is_turbo {
308            stable_diffusion::StableDiffusionConfig::sdxl_turbo(None, None, None)
309        } else {
310            stable_diffusion::StableDiffusionConfig::sdxl(None, None, None)
311        }
312    }
313
314    /// Reload UNet if it was dropped after VAE decode.
315    fn reload_unet_if_needed(&mut self) -> Result<()> {
316        let needs_reload = self
317            .base
318            .loaded
319            .as_ref()
320            .map(|l| l.unet.is_none())
321            .unwrap_or(false);
322
323        if needs_reload {
324            let sd_config = self.sd_config();
325            let loaded = self.base.loaded.as_ref().unwrap();
326            let device = loaded.device.clone();
327            let dtype = loaded.dtype;
328            let _ = loaded;
329
330            self.base.progress.stage_start("Reloading UNet (GPU)");
331            let reload_start = Instant::now();
332            let unet = self.build_unet_for_strategy(&sd_config, &device, dtype)?;
333            self.base.loaded.as_mut().unwrap().unet = Some(unet);
334            self.base
335                .progress
336                .stage_done("Reloading UNet (GPU)", reload_start.elapsed());
337        }
338        Ok(())
339    }
340
341    /// UNet load that branches on `single_file_path` — Civitai single-file
342    /// checkpoints (A1111 naming) get the `SingleFileBackend` dispatch,
343    /// diffusers-layout falls through to candle's `build_unet`. Used by
344    /// every UNet load site (eager `load`, sequential `generate_sequential`,
345    /// and `reload_unet_if_needed`) so the branch logic lives in exactly
346    /// one place.
347    ///
348    /// When `self.pending_loras` is non-empty, the underlying tensor source
349    /// (mmap'd diffusers safetensors or `SingleFileBackend`) is wrapped in
350    /// `super::lora::wrap_backend_with_lora` so the UNet construction loads
351    /// `W' = W + scale·(B @ A)` for every LoRA-targeted layer. The wrapper
352    /// is transparent to `UNet2DConditionModel::new`.
353    fn build_unet_for_strategy(
354        &self,
355        sd_config: &stable_diffusion::StableDiffusionConfig,
356        device: &Device,
357        dtype: DType,
358    ) -> Result<stable_diffusion::unet_2d::UNet2DConditionModel> {
359        let has_lora = !self.pending_loras.is_empty();
360        if let Some(single_file) = self.single_file_path.as_ref() {
361            let remap = Self::load_sdxl_remap(single_file)?;
362            if has_lora {
363                self.build_unet_single_file_with_lora(single_file, &remap, sd_config, device, dtype)
364            } else {
365                Self::build_unet_single_file(single_file, &remap, sd_config, device, dtype)
366            }
367        } else if has_lora {
368            self.build_unet_diffusers_with_lora(sd_config, device, dtype)
369        } else {
370            Ok(sd_config.build_unet(&self.base.paths.transformer, device, 4, false, dtype)?)
371        }
372    }
373
374    /// Build the UNet from a diffusers-layout single safetensors file with
375    /// LoRA wrappers active. Mirrors candle's `build_unet`: open the mmap,
376    /// wrap it in a `SimpleBackend`, layer `SdxlLoraBackend` on top, and feed
377    /// the resulting `VarBuilder` to `UNet2DConditionModel::new`.
378    fn build_unet_diffusers_with_lora(
379        &self,
380        sd_config: &stable_diffusion::StableDiffusionConfig,
381        device: &Device,
382        dtype: DType,
383    ) -> Result<stable_diffusion::unet_2d::UNet2DConditionModel> {
384        use candle_core::safetensors::MmapedSafetensors;
385        use candle_nn::VarBuilder;
386
387        let st = unsafe { MmapedSafetensors::multi(&[&self.base.paths.transformer])? };
388
389        struct MmapBackend {
390            st: MmapedSafetensors,
391        }
392        impl candle_nn::var_builder::SimpleBackend for MmapBackend {
393            fn get(
394                &self,
395                _s: candle_core::Shape,
396                name: &str,
397                _h: candle_nn::Init,
398                dtype: DType,
399                dev: &Device,
400            ) -> candle_core::Result<Tensor> {
401                let t = self.st.load(name, dev)?;
402                if t.dtype() != dtype {
403                    t.to_dtype(dtype)
404                } else {
405                    Ok(t)
406                }
407            }
408            fn get_unchecked(
409                &self,
410                name: &str,
411                dtype: DType,
412                dev: &Device,
413            ) -> candle_core::Result<Tensor> {
414                let t = self.st.load(name, dev)?;
415                if t.dtype() != dtype {
416                    t.to_dtype(dtype)
417                } else {
418                    Ok(t)
419                }
420            }
421            fn contains_tensor(&self, name: &str) -> bool {
422                self.st.get(name).is_ok()
423            }
424        }
425        let inner: Box<dyn candle_nn::var_builder::SimpleBackend> = Box::new(MmapBackend { st });
426        let wrapped = self.wrap_with_loras(inner)?;
427        let vb = VarBuilder::from_backend(wrapped, dtype, device.clone());
428        Ok(stable_diffusion::unet_2d::UNet2DConditionModel::new(
429            vb,
430            4,
431            4,
432            false,
433            sd_config.unet().clone(),
434        )?)
435    }
436
437    /// Same as `build_unet_single_file` but with the LoRA wrapper layered on
438    /// top of `SingleFileBackend`. The SDXL `SingleFileBackend` translates
439    /// diffusers candle keys onto the mmap'd A1111 checkpoint; the LoRA
440    /// wrapper then intercepts the merged tensor and adds the per-layer
441    /// delta before it lands on the UNet constructor.
442    fn build_unet_single_file_with_lora(
443        &self,
444        single_file: &std::path::Path,
445        remap: &crate::loader::SdxlRemap,
446        sd_config: &stable_diffusion::StableDiffusionConfig,
447        device: &Device,
448        dtype: DType,
449    ) -> Result<stable_diffusion::unet_2d::UNet2DConditionModel> {
450        use crate::loader::SingleFileBackend;
451        use candle_nn::VarBuilder;
452
453        let backend = SingleFileBackend::from_sdxl_unet(single_file, remap)?;
454        let inner: Box<dyn candle_nn::var_builder::SimpleBackend> = Box::new(backend);
455        let wrapped = self.wrap_with_loras(inner)?;
456        let vb = VarBuilder::from_backend(wrapped, dtype, device.clone());
457        Ok(stable_diffusion::unet_2d::UNet2DConditionModel::new(
458            vb,
459            4,
460            4,
461            false,
462            sd_config.unet().clone(),
463        )?)
464    }
465
466    /// Wrap an `inner` SimpleBackend with the SDXL LoRA backend. Resolves the
467    /// `pending_loras` list into `LoraSpec`s (parsed-LoRA cache hits keep
468    /// adapter parsing cheap across requests).
469    fn wrap_with_loras(
470        &self,
471        inner: Box<dyn candle_nn::var_builder::SimpleBackend>,
472    ) -> Result<Box<dyn candle_nn::var_builder::SimpleBackend>> {
473        let adapters = super::lora::load_lora_adapters(&self.pending_loras, &self.base.progress)?;
474        let specs: Vec<super::lora::SdxlLoraSpec<'_>> = adapters
475            .iter()
476            .zip(self.pending_loras.iter())
477            .map(|(adapter, w)| super::lora::SdxlLoraSpec {
478                adapter: adapter.as_ref(),
479                scale: w.scale,
480                path_hash: super::lora::lora_path_hash(&w.path),
481            })
482            .collect();
483        super::lora::wrap_backend_with_lora(inner, &specs, &self.base.progress, None)
484    }
485
486    /// VAE load with the same single-file vs diffusers branch as
487    /// `build_unet_for_strategy`. Sequential SDXL loads the VAE twice
488    /// (once for img2img encode, once for post-denoise decode), so this
489    /// helper keeps the branch logic in a single place.
490    fn build_vae_for_strategy(
491        &self,
492        sd_config: &stable_diffusion::StableDiffusionConfig,
493        device: &Device,
494        dtype: DType,
495    ) -> Result<stable_diffusion::vae::AutoEncoderKL> {
496        if let Some(single_file) = self.single_file_path.as_ref() {
497            let remap = Self::load_sdxl_remap(single_file)?;
498            Self::build_vae_single_file(single_file, &remap, sd_config, device, dtype)
499        } else {
500            self.build_vae_diffusers(sd_config, device, dtype)
501        }
502    }
503
/// Test helper: fetch the engine's configured VAE weights from the shared
/// pool as CPU tensors (`Ok(None)` when there is no shared pool).
#[cfg(test)]
fn load_vae_cpu_tensors(&self) -> Result<Option<Arc<HashMap<String, Tensor>>>> {
    let vae_path: &Path = self.base.paths.vae.as_path();
    self.load_vae_cpu_tensors_for_path(vae_path)
}
508
509    fn load_vae_cpu_tensors_for_path(
510        &self,
511        vae_path: &Path,
512    ) -> Result<Option<Arc<HashMap<String, Tensor>>>> {
513        let Some(shared_pool) = &self.shared_pool else {
514            return Ok(None);
515        };
516        shared_pool
517            .lock()
518            .unwrap()
519            .load_safetensors_cpu_tensors(std::slice::from_ref(&vae_path))
520    }
521
522    fn load_vae_var_builder<'a>(
523        &self,
524        vae_path: &Path,
525        dtype: DType,
526        device: &Device,
527        component: &str,
528    ) -> Result<candle_nn::VarBuilder<'a>> {
529        if let Some(tensors) = self.load_vae_cpu_tensors_for_path(vae_path)? {
530            return Ok(crate::encoders::park::varbuilder_from_parked(
531                tensors.as_ref(),
532                dtype,
533                device,
534            ));
535        }
536
537        crate::weight_loader::load_safetensors_with_progress(
538            std::slice::from_ref(&vae_path),
539            dtype,
540            device,
541            component,
542            &self.base.progress,
543        )
544    }
545
546    fn build_vae_diffusers(
547        &self,
548        sd_config: &stable_diffusion::StableDiffusionConfig,
549        device: &Device,
550        dtype: DType,
551    ) -> Result<stable_diffusion::vae::AutoEncoderKL> {
552        let vb = self.load_vae_var_builder(&self.base.paths.vae, dtype, device, "VAE")?;
553        Ok(stable_diffusion::vae::AutoEncoderKL::new(
554            vb,
555            3,
556            3,
557            sd_config.autoencoder().clone(),
558        )?)
559    }
560
561    /// Load all SDXL model components (Eager mode).
562    ///
563    /// Two materialisation paths share this entry point — diffusers-layout
564    /// (separate component files via candle's `build_*` helpers) and
565    /// single-file (Civitai checkpoint via a custom `SingleFileBackend`
566    /// that translates each diffusers `vb.get(name)` into mmap'd reads,
567    /// including row-wise slicing for the CLIP-G fused QKV slabs). The
568    /// branch is on `self.single_file_path`.
569    ///
570    /// On error, `self.base.loaded` remains `None` — all components are
571    /// assembled into local variables and only stored on success, so
572    /// partial loads cannot leave the engine in an inconsistent state.
573    pub fn load(&mut self) -> Result<()> {
574        if self.base.loaded.is_some() {
575            return Ok(());
576        }
577
578        // Sequential mode defers loading to generate_sequential()
579        if self.base.load_strategy == LoadStrategy::Sequential {
580            return Ok(());
581        }
582
583        let (clip_encoder, clip_tokenizer, clip_encoder_2, clip_tokenizer_2) =
584            self.validate_paths()?;
585
586        tracing::info!(model = %self.base.model_name, "loading SDXL model components...");
587
588        let device = crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)?;
589        let dtype = if crate::device::is_gpu(&device) {
590            DType::F16
591        } else {
592            DType::F32
593        };
594
595        let sd_config = self.sd_config();
596
597        // Tier 1: honor `placement.text_encoders` for both CLIPs as a group.
598        let tier1 = self
599            .pending_placement
600            .as_ref()
601            .map(|p| p.text_encoders)
602            .unwrap_or_default();
603        let clip_device = crate::device::resolve_device(Some(tier1), || Ok(device.clone()))?;
604
605        // Resolve VAE precision once at load — see LoadedSDXL::vae_dtype.
606        let vae_dtype = resolve_sdxl_vae_dtype(dtype, self.single_file_path.is_some());
607        let (unet, vae, clip_l, clip_g) = if let Some(single_file) = self.single_file_path.clone() {
608            self.load_components_single_file(
609                &single_file,
610                &sd_config,
611                &device,
612                &clip_device,
613                dtype,
614                vae_dtype,
615            )?
616        } else {
617            self.load_components_diffusers(
618                &clip_encoder,
619                &clip_encoder_2,
620                &sd_config,
621                &device,
622                &clip_device,
623                dtype,
624                vae_dtype,
625            )?
626        };
627
628        let tokenizer_l = self.load_clip_tokenizer(&clip_tokenizer, "CLIP-L")?;
629        let tokenizer_g = self.load_clip_tokenizer(&clip_tokenizer_2, "CLIP-G")?;
630
631        self.base.loaded = Some(LoadedSDXL {
632            unet: Some(unet),
633            vae,
634            clip_l,
635            clip_g,
636            tokenizer_l,
637            tokenizer_g,
638            sd_config,
639            device,
640            clip_device,
641            dtype,
642            vae_dtype,
643        });
644
645        tracing::info!(model = %self.base.model_name, "all SDXL components loaded successfully");
646        Ok(())
647    }
648
649    /// Diffusers-layout component loader (existing pre-2.6 path).
650    #[allow(clippy::too_many_arguments)]
651    fn load_components_diffusers(
652        &mut self,
653        clip_encoder: &std::path::Path,
654        clip_encoder_2: &std::path::Path,
655        sd_config: &stable_diffusion::StableDiffusionConfig,
656        device: &Device,
657        clip_device: &Device,
658        dtype: DType,
659        vae_dtype: DType,
660    ) -> Result<(
661        stable_diffusion::unet_2d::UNet2DConditionModel,
662        stable_diffusion::vae::AutoEncoderKL,
663        stable_diffusion::clip::ClipTextTransformer,
664        stable_diffusion::clip::ClipTextTransformer,
665    )> {
666        self.base.progress.stage_start("Loading UNet (GPU)");
667        let unet_start = Instant::now();
668        let unet = sd_config.build_unet(&self.base.paths.transformer, device, 4, false, dtype)?;
669        self.base
670            .progress
671            .stage_done("Loading UNet (GPU)", unet_start.elapsed());
672
673        self.base.progress.stage_start("Loading VAE (GPU)");
674        let vae_start = Instant::now();
675        let vae = self.build_vae_diffusers(sd_config, device, vae_dtype)?;
676        self.base
677            .progress
678            .stage_done("Loading VAE (GPU)", vae_start.elapsed());
679
680        self.base.progress.stage_start("Loading CLIP-L encoder");
681        let clip_l_start = Instant::now();
682        let clip_l = stable_diffusion::build_clip_transformer(
683            &sd_config.clip,
684            clip_encoder,
685            clip_device,
686            DType::F32,
687        )?;
688        self.base
689            .progress
690            .stage_done("Loading CLIP-L encoder", clip_l_start.elapsed());
691
692        self.base.progress.stage_start("Loading CLIP-G encoder");
693        let clip_g_start = Instant::now();
694        let clip2_config = sd_config
695            .clip2
696            .as_ref()
697            .ok_or_else(|| anyhow::anyhow!("SDXL config missing clip2 configuration"))?;
698        let clip_g = stable_diffusion::build_clip_transformer(
699            clip2_config,
700            clip_encoder_2,
701            clip_device,
702            DType::F32,
703        )?;
704        self.base
705            .progress
706            .stage_done("Loading CLIP-G encoder", clip_g_start.elapsed());
707
708        Ok((unet, vae, clip_l, clip_g))
709    }
710
711    /// Single-file (Civitai) component loader (phase 2.6 + 2.8.5).
712    ///
713    /// Header-parses the checkpoint, builds the diffusers→A1111 remap
714    /// (incl. CLIP-G `RenameOutput::FusedSlice` entries for the OpenCLIP
715    /// `attn.in_proj_*` slabs), wraps it in `SingleFileBackend`, and feeds
716    /// four `VarBuilder::from_backend(SingleFileBackend)` instances to
717    /// candle's per-component constructors: UNet, VAE, CLIP-L, CLIP-G.
718    /// All four read from the same single-file mmap.
719    ///
720    /// Reaches into `sd_config.unet()` / `.autoencoder()` accessors exposed
721    /// by candle-transformers-mold 0.9.12 (utensils/candle PR #1).
722    fn load_components_single_file(
723        &mut self,
724        single_file: &std::path::Path,
725        sd_config: &stable_diffusion::StableDiffusionConfig,
726        device: &Device,
727        clip_device: &Device,
728        dtype: DType,
729        vae_dtype: DType,
730    ) -> Result<(
731        stable_diffusion::unet_2d::UNet2DConditionModel,
732        stable_diffusion::vae::AutoEncoderKL,
733        stable_diffusion::clip::ClipTextTransformer,
734        stable_diffusion::clip::ClipTextTransformer,
735    )> {
736        let remap = Self::load_sdxl_remap(single_file)?;
737
738        self.base.progress.stage_start("Loading UNet (single-file)");
739        let unet_start = Instant::now();
740        let unet = Self::build_unet_single_file(single_file, &remap, sd_config, device, dtype)?;
741        self.base
742            .progress
743            .stage_done("Loading UNet (single-file)", unet_start.elapsed());
744
745        self.base.progress.stage_start("Loading VAE (single-file)");
746        let vae_start = Instant::now();
747        let vae = Self::build_vae_single_file(single_file, &remap, sd_config, device, vae_dtype)?;
748        self.base
749            .progress
750            .stage_done("Loading VAE (single-file)", vae_start.elapsed());
751
752        self.base
753            .progress
754            .stage_start("Loading CLIP-L (single-file)");
755        let clip_l_start = Instant::now();
756        let clip_l =
757            Self::build_clip_l_single_file(single_file, &remap, &sd_config.clip, clip_device)?;
758        self.base
759            .progress
760            .stage_done("Loading CLIP-L (single-file)", clip_l_start.elapsed());
761
762        self.base
763            .progress
764            .stage_start("Loading CLIP-G (single-file)");
765        let clip_g_start = Instant::now();
766        let clip2_config = sd_config
767            .clip2
768            .as_ref()
769            .ok_or_else(|| anyhow::anyhow!("SDXL config missing clip2 configuration"))?;
770        let clip_g =
771            Self::build_clip_g_single_file(single_file, &remap, clip2_config, clip_device)?;
772        self.base
773            .progress
774            .stage_done("Loading CLIP-G (single-file)", clip_g_start.elapsed());
775
776        Ok((unet, vae, clip_l, clip_g))
777    }
778
779    /// Header-parse the single-file checkpoint and build the SDXL
780    /// diffusers→A1111 remap. Cheap (no tensor data is mmap'd) — sequential
781    /// reload calls this each time a component is reloaded after dropping.
782    fn load_sdxl_remap(single_file: &std::path::Path) -> Result<crate::loader::SdxlRemap> {
783        use crate::loader::{build_sdxl_remap, single_file as single_file_loader};
784        use mold_catalog::families::Family;
785        let bundle = single_file_loader::load(single_file, Family::Sdxl)
786            .map_err(|e| anyhow::anyhow!("partition single-file SDXL checkpoint: {e}"))?;
787        build_sdxl_remap(&bundle)
788            .map_err(|e| anyhow::anyhow!("build SDXL diffusers→A1111 remap: {e}"))
789    }
790
791    /// Build a UNet from a Civitai single-file checkpoint via
792    /// `SingleFileBackend`. Used by both eager `load_components_single_file`
793    /// and sequential `generate_sequential` so the dispatch is identical
794    /// across modes.
795    fn build_unet_single_file(
796        single_file: &std::path::Path,
797        remap: &crate::loader::SdxlRemap,
798        sd_config: &stable_diffusion::StableDiffusionConfig,
799        device: &Device,
800        dtype: DType,
801    ) -> Result<stable_diffusion::unet_2d::UNet2DConditionModel> {
802        use crate::loader::SingleFileBackend;
803        use candle_nn::VarBuilder;
804        let backend = SingleFileBackend::from_sdxl_unet(single_file, remap)?;
805        let vb = VarBuilder::from_backend(Box::new(backend), dtype, device.clone());
806        Ok(stable_diffusion::unet_2d::UNet2DConditionModel::new(
807            vb,
808            4,
809            4,
810            false,
811            sd_config.unet().clone(),
812        )?)
813    }
814
815    /// Build a VAE from a Civitai single-file checkpoint via a VAE-scoped
816    /// `SingleFileBackend`.
817    fn build_vae_single_file(
818        single_file: &std::path::Path,
819        remap: &crate::loader::SdxlRemap,
820        sd_config: &stable_diffusion::StableDiffusionConfig,
821        device: &Device,
822        dtype: DType,
823    ) -> Result<stable_diffusion::vae::AutoEncoderKL> {
824        use crate::loader::SingleFileBackend;
825        use candle_nn::VarBuilder;
826        let backend = SingleFileBackend::from_sdxl_vae(single_file, remap)?;
827        let vb = VarBuilder::from_backend(Box::new(backend), dtype, device.clone());
828        Ok(stable_diffusion::vae::AutoEncoderKL::new(
829            vb,
830            3,
831            3,
832            sd_config.autoencoder().clone(),
833        )?)
834    }
835
836    /// Build CLIP-L from a Civitai single-file checkpoint via a CLIP-L-scoped
837    /// `SingleFileBackend`. Scoping is critical: CLIP-L and CLIP-G produce
838    /// the same diffusers keys (e.g. `text_model.embeddings.token_embedding.weight`
839    /// and every encoder layer's `self_attn.{q,k,v,out}_proj.weight`), so
840    /// an all-in-one entries map collapses them and CLIP-L's
841    /// `ClipTextTransformer` would materialise with CLIP-G's `[vocab, 1280]`
842    /// weights instead of its own `[vocab, 768]`. The next
843    /// `Embedding::forward` reshape then blows up with `lhs: [77, 1280],
844    /// rhs: [1, 77, 768]`. See
845    /// `loader::single_file_backend::tests::sdxl_clip_l_scoped_backend_returns_clip_l_tensor_when_keys_collide_with_clip_g`.
846    fn build_clip_l_single_file(
847        single_file: &std::path::Path,
848        remap: &crate::loader::SdxlRemap,
849        clip_config: &stable_diffusion::clip::Config,
850        clip_device: &Device,
851    ) -> Result<stable_diffusion::clip::ClipTextTransformer> {
852        use crate::loader::SingleFileBackend;
853        use candle_nn::VarBuilder;
854        let backend = SingleFileBackend::from_sdxl_clip_l(single_file, remap)?;
855        let vb = VarBuilder::from_backend(Box::new(backend), DType::F32, clip_device.clone());
856        Ok(stable_diffusion::clip::ClipTextTransformer::new(
857            vb,
858            clip_config,
859        )?)
860    }
861
862    /// Build CLIP-G from a Civitai single-file checkpoint via a CLIP-G-scoped
863    /// `SingleFileBackend`. Scoped factory keeps CLIP-G's FusedSlice entries
864    /// for the OpenCLIP `attn.in_proj_*` slabs and excludes CLIP-L's overlap.
865    fn build_clip_g_single_file(
866        single_file: &std::path::Path,
867        remap: &crate::loader::SdxlRemap,
868        clip_config: &stable_diffusion::clip::Config,
869        clip_device: &Device,
870    ) -> Result<stable_diffusion::clip::ClipTextTransformer> {
871        use crate::loader::SingleFileBackend;
872        use candle_nn::VarBuilder;
873        let backend = SingleFileBackend::from_sdxl_clip_g(single_file, remap)?;
874        let vb = VarBuilder::from_backend(Box::new(backend), DType::F32, clip_device.clone());
875        Ok(stable_diffusion::clip::ClipTextTransformer::new(
876            vb,
877            clip_config,
878        )?)
879    }
880
881    /// Tokenize a prompt for a CLIP encoder, padding/truncating to max_len tokens.
882    fn tokenize(
883        tokenizer: &tokenizers::Tokenizer,
884        prompt: &str,
885        max_len: usize,
886        device: &Device,
887    ) -> Result<Tensor> {
888        let encoding = tokenizer
889            .encode(prompt, true)
890            .map_err(|e| anyhow::anyhow!("tokenization failed: {e}"))?;
891        let mut ids = encoding.get_ids().to_vec();
892        ids.truncate(max_len);
893        // Pad with 0s (EOS/PAD token for CLIP)
894        while ids.len() < max_len {
895            ids.push(0);
896        }
897        let ids = ids.into_iter().map(|i| i as i64).collect::<Vec<_>>();
898        Ok(Tensor::new(ids, device)?.unsqueeze(0)?)
899    }
900
901    /// Run the denoising loop (shared between eager and sequential).
902    ///
903    /// `start_step` allows starting from a later timestep for img2img (0 = full txt2img).
904    #[allow(clippy::too_many_arguments)]
905    #[allow(clippy::too_many_arguments)]
906    fn denoise_loop(
907        &self,
908        unet: &stable_diffusion::unet_2d::UNet2DConditionModel,
909        text_embeddings: &Tensor,
910        sched: Scheduler,
911        latents: &mut Tensor,
912        guidance: f64,
913        cfg_plus: bool,
914        steps: u32,
915        start_step: usize,
916        inpaint_ctx: Option<&crate::img_utils::InpaintContext>,
917    ) -> Result<()> {
918        let use_cfg = cfg_active(guidance);
919        let mut scheduler = crate::scheduler::build_scheduler(
920            sched,
921            steps as usize,
922            PredictionType::Epsilon,
923            self.is_turbo,
924        )?;
925        let timesteps = scheduler.timesteps().to_vec();
926        let active_timesteps = &timesteps[start_step..];
927
928        // CFG++ requires the doubled `[uncond, cond]` forward (so we can read
929        // the uncond row at integration time) AND a DDIM scheduler (the only
930        // one whose alpha schedule we mirror). Other combinations fall back
931        // to standard CFG with a one-shot warn so misconfigurations surface.
932        let cfg_plus_schedule = if cfg_plus && use_cfg && matches!(sched, Scheduler::Ddim) {
933            Some(DdimAlphaSchedule::from_default(steps as usize))
934        } else {
935            if cfg_plus && !use_cfg {
936                tracing::warn!(
937                    guidance,
938                    "cfg_plus requested but cfg_scale ≈ 1.0 — falling back to standard step (no uncond available)"
939                );
940            } else if cfg_plus {
941                tracing::warn!(
942                    scheduler = ?sched,
943                    "cfg_plus requested but only DDIM is supported on SDXL/SD1.5 — falling back to standard step. Re-run with `--scheduler ddim` to enable CFG++."
944                );
945            }
946            None
947        };
948
949        let denoise_label = format!("Denoising ({} steps)", active_timesteps.len());
950        self.base.progress.stage_start(&denoise_label);
951        let denoise_start = Instant::now();
952
953        for (step_idx, &t) in active_timesteps.iter().enumerate() {
954            let step_start = std::time::Instant::now();
955            let latent_input = if use_cfg {
956                Tensor::cat(&[&*latents, &*latents], 0)?
957            } else {
958                latents.clone()
959            };
960
961            let latent_input = scheduler.scale_model_input(latent_input, t)?;
962            let noise_pred = unet.forward(&latent_input, t as f64, text_embeddings)?;
963
964            // Hold onto the raw uncond row when CFG++ is active so we can use
965            // it as the renoise direction below; standard path discards it
966            // after CFG blending.
967            let (noise_pred_blended, noise_pred_uncond_opt) = if use_cfg {
968                let chunks = noise_pred.chunk(2, 0)?;
969                let noise_pred_uncond = chunks[0].clone();
970                let noise_pred_cond = &chunks[1];
971                let blended =
972                    (&noise_pred_uncond + ((noise_pred_cond - &noise_pred_uncond)? * guidance)?)?;
973                (blended, Some(noise_pred_uncond))
974            } else {
975                (noise_pred, None)
976            };
977
978            *latents = match (cfg_plus_schedule.as_ref(), noise_pred_uncond_opt.as_ref()) {
979                (Some(ddim_sched), Some(eps_uncond)) => {
980                    ddim_sched.cfg_plus_step(&*latents, &noise_pred_blended, eps_uncond, t)?
981                }
982                _ => scheduler.step(&noise_pred_blended, t, &*latents)?,
983            };
984
985            if let Some(ctx) = inpaint_ctx {
986                let noised_original =
987                    scheduler.add_noise(&ctx.original_latents, ctx.noise.clone(), t)?;
988                *latents = crate::img2img::blend_inpaint_latents(&*latents, ctx, &noised_original)?;
989            }
990
991            self.base.progress.emit(ProgressEvent::DenoiseStep {
992                step: step_idx + 1,
993                total: active_timesteps.len(),
994                elapsed: step_start.elapsed(),
995            });
996        }
997
998        self.base
999            .progress
1000            .stage_done(&denoise_label, denoise_start.elapsed());
1001        Ok(())
1002    }
1003
1004    /// Prepare img2img latents: VAE encode source image, add noise at the appropriate timestep.
1005    /// Returns (noised_latents, start_step, encoded, noise).
1006    ///
1007    /// `dtype` is the engine-wide compute dtype (used for noise + denoise loop).
1008    /// `vae_dtype` may differ when `MOLD_VAE_DTYPE` forces fp32 — it controls
1009    /// the source-tensor decode and the VAE encode input precision; the
1010    /// returned encoded latents are cast back to `dtype` for the denoise.
1011    #[allow(clippy::too_many_arguments)]
1012    fn prepare_img2img_latents(
1013        &self,
1014        vae: &stable_diffusion::vae::AutoEncoderKL,
1015        source_bytes: &[u8],
1016        width: u32,
1017        height: u32,
1018        strength: f64,
1019        steps: u32,
1020        sched: Scheduler,
1021        seed: u64,
1022        device: &Device,
1023        dtype: DType,
1024        vae_dtype: DType,
1025    ) -> Result<(Tensor, usize, Tensor, Tensor)> {
1026        use crate::img_utils::{decode_source_image, NormalizeRange};
1027        let vae_scale = if self.is_turbo {
1028            VAE_SCALE_TURBO
1029        } else {
1030            VAE_SCALE_STANDARD
1031        };
1032        let cache_key = image_size_cache_key(source_bytes, width, height);
1033        let (encoded, cache_hit) = get_or_insert_cached_tensor(
1034            &self.source_latent_cache,
1035            cache_key,
1036            device,
1037            dtype,
1038            || {
1039                self.base
1040                    .progress
1041                    .stage_start("Encoding source image (VAE)");
1042                let encode_start = Instant::now();
1043
1044                let source_tensor = decode_source_image(
1045                    source_bytes,
1046                    width,
1047                    height,
1048                    NormalizeRange::MinusOneToOne,
1049                    device,
1050                    vae_dtype,
1051                )?;
1052                let encoded = vae.encode(&source_tensor)?;
1053                let encoded = (encoded.mode()? * vae_scale)?;
1054                // VAE may have been loaded at fp32 (banding fix); cast the
1055                // encoded latents back to engine dtype so the rest of the
1056                // denoise loop stays at its natural precision.
1057                let encoded = encoded.to_dtype(dtype)?;
1058
1059                self.base
1060                    .progress
1061                    .stage_done("Encoding source image (VAE)", encode_start.elapsed());
1062                Ok(encoded)
1063            },
1064        )?;
1065        if cache_hit {
1066            self.base.progress.cache_hit("source image latents");
1067        }
1068
1069        let start_step = crate::img2img::img2img_start_index(steps as usize, strength);
1070
1071        let scheduler = crate::scheduler::build_scheduler(
1072            sched,
1073            steps as usize,
1074            PredictionType::Epsilon,
1075            self.is_turbo,
1076        )?;
1077        let timesteps = scheduler.timesteps().to_vec();
1078
1079        let latent_h = height as usize / 8;
1080        let latent_w = width as usize / 8;
1081        let noise =
1082            crate::engine::seeded_randn(seed, &[1, 4, latent_h, latent_w], device, DType::F32)?;
1083        let noise = noise.to_dtype(dtype)?;
1084
1085        let noised = if start_step < timesteps.len() {
1086            scheduler.add_noise(&encoded, noise.clone(), timesteps[start_step])?
1087        } else {
1088            encoded.clone()
1089        };
1090
1091        tracing::info!(
1092            start_step,
1093            total_steps = steps,
1094            strength,
1095            "img2img: starting from step {start_step}"
1096        );
1097
1098        Ok((noised, start_step, encoded, noise))
1099    }
1100
1101    /// Encode prompt with both CLIP encoders.
1102    #[allow(clippy::too_many_arguments)]
1103    fn encode_prompt(
1104        &self,
1105        clip_l: &stable_diffusion::clip::ClipTextTransformer,
1106        clip_g: &stable_diffusion::clip::ClipTextTransformer,
1107        tokenizer_l: &tokenizers::Tokenizer,
1108        tokenizer_g: &tokenizers::Tokenizer,
1109        prompt: &str,
1110        negative_prompt: &str,
1111        max_len: usize,
1112        device: &Device,
1113        clip_device: &Device,
1114        dtype: DType,
1115        guidance: f64,
1116    ) -> Result<Tensor> {
1117        // SDXL caches the **concatenated** `(uncond, cond)` tensor when CFG is
1118        // active, so the cache key must include the negative prompt and the
1119        // guidance scale. Keying on the positive prompt + guidance alone
1120        // returned a stale uncond branch when the user changed only the
1121        // negative prompt — silent wrong output.
1122        let cache_key = cfg_prompt_cache_key(prompt, negative_prompt, guidance);
1123        let (text_embeddings, cache_hit) =
1124            get_or_insert_cached_tensor(&self.prompt_cache, cache_key, device, dtype, || {
1125                let use_cfg = cfg_active(guidance);
1126
1127                self.base.progress.stage_start("Encoding prompt (CLIP-L)");
1128                let encode_l_start = Instant::now();
1129                let tokens_l = Self::tokenize(tokenizer_l, prompt, max_len, clip_device)?;
1130                let text_emb_l = clip_l.forward(&tokens_l)?;
1131                self.base
1132                    .progress
1133                    .stage_done("Encoding prompt (CLIP-L)", encode_l_start.elapsed());
1134
1135                self.base.progress.stage_start("Encoding prompt (CLIP-G)");
1136                let encode_g_start = Instant::now();
1137                let tokens_g = Self::tokenize(tokenizer_g, prompt, max_len, clip_device)?;
1138                let text_emb_g = clip_g.forward(&tokens_g)?;
1139                self.base
1140                    .progress
1141                    .stage_done("Encoding prompt (CLIP-G)", encode_g_start.elapsed());
1142
1143                let text_embeddings = Tensor::cat(&[&text_emb_l, &text_emb_g], D::Minus1)?;
1144
1145                let text_embeddings = if use_cfg {
1146                    let uncond_tokens_l =
1147                        Self::tokenize(tokenizer_l, negative_prompt, max_len, clip_device)?;
1148                    let uncond_emb_l = clip_l.forward(&uncond_tokens_l)?;
1149                    let uncond_tokens_g =
1150                        Self::tokenize(tokenizer_g, negative_prompt, max_len, clip_device)?;
1151                    let uncond_emb_g = clip_g.forward(&uncond_tokens_g)?;
1152                    let uncond_embeddings =
1153                        Tensor::cat(&[&uncond_emb_l, &uncond_emb_g], D::Minus1)?;
1154                    Tensor::cat(&[&uncond_embeddings, &text_embeddings], 0)?
1155                } else {
1156                    text_embeddings
1157                };
1158
1159                let text_embeddings = text_embeddings.to_device(device)?;
1160                Ok(text_embeddings.to_dtype(dtype)?)
1161            })?;
1162        if cache_hit {
1163            self.base.progress.cache_hit("prompt conditioning");
1164            return Ok(text_embeddings);
1165        }
1166        Ok(text_embeddings)
1167    }
1168
1169    fn cached_mask(
1170        &self,
1171        mask_bytes: &[u8],
1172        latent_h: usize,
1173        latent_w: usize,
1174        device: &Device,
1175        dtype: DType,
1176    ) -> Result<Tensor> {
1177        let key = latent_size_cache_key(mask_bytes, latent_h, latent_w);
1178        let (mask, cache_hit) =
1179            get_or_insert_cached_tensor(&self.mask_cache, key, device, dtype, || {
1180                crate::img_utils::decode_mask_image(mask_bytes, latent_h, latent_w, device, dtype)
1181            })?;
1182        if cache_hit {
1183            self.base.progress.cache_hit("inpaint mask");
1184            return Ok(mask);
1185        }
1186        Ok(mask)
1187    }
1188
1189    /// Generate an image using sequential loading strategy.
1190    ///
1191    /// Loads components one at a time and drops them when done:
1192    /// 1. Load CLIP-L → encode → drop CLIP-L
1193    /// 2. Load CLIP-G → encode → drop CLIP-G
1194    /// 3. Load UNet → denoise → drop UNet
1195    /// 4. Load VAE → decode → drop VAE
1196    fn generate_sequential(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1197        let (clip_encoder, clip_tokenizer, clip_encoder_2, clip_tokenizer_2) =
1198            self.validate_paths()?;
1199
1200        // Check memory budget
1201        if let Some(warning) = check_memory_budget(&self.base.paths, LoadStrategy::Sequential) {
1202            self.base.progress.info(&warning);
1203        }
1204
1205        let device = crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)?;
1206        let dtype = if crate::device::is_gpu(&device) {
1207            DType::F16
1208        } else {
1209            DType::F32
1210        };
1211
1212        let sd_config = self.sd_config();
1213        let max_len = sd_config.clip.max_position_embeddings;
1214
1215        let start = Instant::now();
1216        let seed = req.seed.unwrap_or_else(rand_seed);
1217
1218        let width = req.width as usize;
1219        let height = req.height as usize;
1220        let guidance = req.guidance;
1221
1222        tracing::info!(
1223            prompt = %req.prompt,
1224            seed, width, height,
1225            steps = req.steps,
1226            guidance,
1227            "starting sequential SDXL generation"
1228        );
1229
1230        self.base
1231            .progress
1232            .info("Using sequential loading (load-use-drop) to minimize peak memory");
1233
1234        // --- Phase 1: Encode prompt (check cache first to skip encoder load) ---
1235        let neg = req.negative_prompt.as_deref().unwrap_or("");
1236        let cache_key = cfg_prompt_cache_key(&req.prompt, neg, guidance);
1237        let text_embeddings = if let Some(tensor) =
1238            restore_cached_tensor(&self.prompt_cache, &cache_key, &device, dtype)?
1239        {
1240            self.base.progress.cache_hit("prompt conditioning");
1241            tensor
1242        } else {
1243            if let Some(status) = memory_status_string() {
1244                self.base.progress.info(&status);
1245            }
1246
1247            let tokenizer_l = self.load_clip_tokenizer(&clip_tokenizer, "CLIP-L")?;
1248            let tokenizer_g = self.load_clip_tokenizer(&clip_tokenizer_2, "CLIP-G")?;
1249
1250            let tier1 = self
1251                .pending_placement
1252                .as_ref()
1253                .map(|p| p.text_encoders)
1254                .unwrap_or_default();
1255            let clip_device = crate::device::resolve_device(Some(tier1), || Ok(device.clone()))?;
1256
1257            // Branch on single_file_path so cv:<id>-style Civitai checkpoints
1258            // (A1111 naming) go through `SingleFileBackend` just like eager
1259            // mode does. Without this, candle's diffusers-keyed
1260            // `build_clip_transformer` errors with
1261            // "cannot find tensor text_model.embeddings.token_embedding.weight".
1262            let (clip_l, clip_g) =
1263                if let Some(single_file) = self.single_file_path.clone() {
1264                    let remap = Self::load_sdxl_remap(&single_file)?;
1265
1266                    self.base
1267                        .progress
1268                        .stage_start("Loading CLIP-L (single-file)");
1269                    let clip_l_start = Instant::now();
1270                    let clip_l = Self::build_clip_l_single_file(
1271                        &single_file,
1272                        &remap,
1273                        &sd_config.clip,
1274                        &clip_device,
1275                    )?;
1276                    self.base
1277                        .progress
1278                        .stage_done("Loading CLIP-L (single-file)", clip_l_start.elapsed());
1279
1280                    self.base
1281                        .progress
1282                        .stage_start("Loading CLIP-G (single-file)");
1283                    let clip_g_start = Instant::now();
1284                    let clip2_config = sd_config.clip2.as_ref().ok_or_else(|| {
1285                        anyhow::anyhow!("SDXL config missing clip2 configuration")
1286                    })?;
1287                    let clip_g = Self::build_clip_g_single_file(
1288                        &single_file,
1289                        &remap,
1290                        clip2_config,
1291                        &clip_device,
1292                    )?;
1293                    self.base
1294                        .progress
1295                        .stage_done("Loading CLIP-G (single-file)", clip_g_start.elapsed());
1296
1297                    (clip_l, clip_g)
1298                } else {
1299                    self.base.progress.stage_start("Loading CLIP-L encoder");
1300                    let clip_l_start = Instant::now();
1301                    let clip_l = stable_diffusion::build_clip_transformer(
1302                        &sd_config.clip,
1303                        &clip_encoder,
1304                        &clip_device,
1305                        DType::F32,
1306                    )?;
1307                    self.base
1308                        .progress
1309                        .stage_done("Loading CLIP-L encoder", clip_l_start.elapsed());
1310
1311                    self.base.progress.stage_start("Loading CLIP-G encoder");
1312                    let clip_g_start = Instant::now();
1313                    let clip2_config = sd_config.clip2.as_ref().ok_or_else(|| {
1314                        anyhow::anyhow!("SDXL config missing clip2 configuration")
1315                    })?;
1316                    let clip_g = stable_diffusion::build_clip_transformer(
1317                        clip2_config,
1318                        &clip_encoder_2,
1319                        &clip_device,
1320                        DType::F32,
1321                    )?;
1322                    self.base
1323                        .progress
1324                        .stage_done("Loading CLIP-G encoder", clip_g_start.elapsed());
1325
1326                    (clip_l, clip_g)
1327                };
1328
1329            let text_embeddings = self.encode_prompt(
1330                &clip_l,
1331                &clip_g,
1332                &tokenizer_l,
1333                &tokenizer_g,
1334                &req.prompt,
1335                neg,
1336                max_len,
1337                &device,
1338                &clip_device,
1339                dtype,
1340                guidance,
1341            )?;
1342
1343            drop(clip_l);
1344            drop(clip_g);
1345            self.base.progress.info("Freed CLIP-L and CLIP-G encoders");
1346            tracing::info!("CLIP encoders dropped (sequential mode)");
1347
1348            text_embeddings
1349        };
1350
1351        // --- Phase 2: Load UNet and denoise ---
1352        let unet_size = std::fs::metadata(&self.base.paths.transformer)
1353            .map(|m| m.len())
1354            .unwrap_or(0);
1355        // SDXL runs CFG by default → batch=2 unless guidance ≈ 1 (LCM/Turbo).
1356        let unet_batch = if req.guidance > 1.0 { 2 } else { 1 };
1357        let unet_activation_budget = crate::device::activation_bytes(
1358            req.width,
1359            req.height,
1360            unet_batch,
1361            crate::device::dtype_bytes(dtype),
1362            crate::device::ActivationFamily::SdxlUnet,
1363        );
1364        preflight_memory_check("UNet", unet_size, unet_activation_budget)?;
1365        if let Some(status) = memory_status_string() {
1366            self.base.progress.info(&status);
1367        }
1368
1369        self.base.progress.stage_start("Loading UNet (GPU)");
1370        let unet_start = Instant::now();
1371        let unet = self.build_unet_for_strategy(&sd_config, &device, dtype)?;
1372        self.base
1373            .progress
1374            .stage_done("Loading UNet (GPU)", unet_start.elapsed());
1375
1376        let sched = req.scheduler.unwrap_or(self.scheduler);
1377        let is_img2img = req.source_image.is_some();
1378
1379        let (mut latents, start_step, inpaint_ctx) = if let Some(ref source_bytes) =
1380            req.source_image
1381        {
1382            self.base
1383                .progress
1384                .info("img2img mode: encoding source image before denoising");
1385
1386            self.base.progress.stage_start("Loading VAE (GPU)");
1387            let vae_start_t = Instant::now();
1388            let vae_dtype = resolve_sdxl_vae_dtype(dtype, self.single_file_path.is_some());
1389            let vae = self.build_vae_for_strategy(&sd_config, &device, vae_dtype)?;
1390            self.base
1391                .progress
1392                .stage_done("Loading VAE (GPU)", vae_start_t.elapsed());
1393
1394            let (latents, start_step, encoded, noise) = self.prepare_img2img_latents(
1395                &vae,
1396                source_bytes,
1397                req.width,
1398                req.height,
1399                req.strength,
1400                req.steps,
1401                sched,
1402                seed,
1403                &device,
1404                dtype,
1405                vae_dtype,
1406            )?;
1407
1408            let inpaint_ctx = if let Some(ref mask_bytes) = req.mask_image {
1409                let mask = self.cached_mask(mask_bytes, height / 8, width / 8, &device, dtype)?;
1410                Some(crate::img_utils::InpaintContext {
1411                    original_latents: encoded,
1412                    mask,
1413                    noise,
1414                })
1415            } else {
1416                None
1417            };
1418
1419            drop(vae);
1420            self.base
1421                .progress
1422                .info("Freed VAE (will reload for decode)");
1423            device.synchronize()?;
1424
1425            (latents, start_step, inpaint_ctx)
1426        } else {
1427            let latent_h = height / 8;
1428            let latent_w = width / 8;
1429            let init_scheduler = crate::scheduler::build_scheduler(
1430                sched,
1431                req.steps as usize,
1432                PredictionType::Epsilon,
1433                self.is_turbo,
1434            )?;
1435            let init_noise_sigma = init_scheduler.init_noise_sigma();
1436            drop(init_scheduler);
1437            let latents = (crate::engine::seeded_randn(
1438                seed,
1439                &[1, 4, latent_h, latent_w],
1440                &device,
1441                DType::F32,
1442            )? * init_noise_sigma)?;
1443            (latents.to_dtype(dtype)?, 0, None)
1444        };
1445
1446        self.denoise_loop(
1447            &unet,
1448            &text_embeddings,
1449            sched,
1450            &mut latents,
1451            guidance,
1452            resolve_cfg_plus(req),
1453            req.steps,
1454            start_step,
1455            inpaint_ctx.as_ref(),
1456        )?;
1457
1458        drop(inpaint_ctx);
1459        drop(unet);
1460        drop(text_embeddings);
1461        device.synchronize()?;
1462        self.base.progress.info("Freed UNet");
1463        tracing::info!("UNet dropped (sequential mode)");
1464
1465        // --- Phase 3: Load VAE and decode ---
1466        let vae_load_label = if is_img2img {
1467            "Reloading VAE (GPU)"
1468        } else {
1469            "Loading VAE (GPU)"
1470        };
1471        self.base.progress.stage_start(vae_load_label);
1472        let vae_start = Instant::now();
1473        let vae_dtype = resolve_sdxl_vae_dtype(dtype, self.single_file_path.is_some());
1474        let vae = self.build_vae_for_strategy(&sd_config, &device, vae_dtype)?;
1475        self.base
1476            .progress
1477            .stage_done(vae_load_label, vae_start.elapsed());
1478
1479        self.base.progress.stage_start("VAE decode");
1480        let vae_decode_start = Instant::now();
1481
1482        let vae_scale = if self.is_turbo {
1483            VAE_SCALE_TURBO
1484        } else {
1485            VAE_SCALE_STANDARD
1486        };
1487        let latents = (latents / vae_scale)?;
1488        let latents_for_vae = latents.to_dtype(vae_dtype)?;
1489        let device_for_sync = device.clone();
1490        let img = crate::vae_tiling::decode_with_oom_fallback(
1491            &latents_for_vae,
1492            |t| vae.decode(t).map_err(Into::into),
1493            || {
1494                if let Err(e) = device_for_sync.synchronize() {
1495                    tracing::warn!(
1496                        "SDXL (sequential) device.synchronize() after VAE OOM failed: {e}"
1497                    );
1498                }
1499            },
1500        )?;
1501
1502        let img = ((img / 2.)? + 0.5)?.clamp(0f32, 1f32)?;
1503        let img = (img * 255.)?.to_dtype(DType::U8)?;
1504        let img = img.squeeze(0)?;
1505
1506        self.base
1507            .progress
1508            .stage_done("VAE decode", vae_decode_start.elapsed());
1509
1510        // VAE dropped here
1511        let output_metadata = build_output_metadata(req, seed, Some(sched));
1512        let image_bytes = encode_image(
1513            &img,
1514            req.resolved_output_format(),
1515            req.width,
1516            req.height,
1517            output_metadata.as_ref(),
1518        )?;
1519
1520        let generation_time_ms = start.elapsed().as_millis() as u64;
1521        tracing::info!(
1522            generation_time_ms,
1523            seed,
1524            "sequential SDXL generation complete"
1525        );
1526
1527        Ok(GenerateResponse {
1528            images: vec![ImageData {
1529                data: image_bytes,
1530                format: req.resolved_output_format(),
1531                width: req.width,
1532                height: req.height,
1533                index: 0,
1534            }],
1535            generation_time_ms,
1536            model: req.model.clone(),
1537            seed_used: seed,
1538            video: None,
1539            gpu: None,
1540        })
1541    }
1542}
1543
1544impl SDXLEngine {
1545    fn generate_inner(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1546        // Sequential mode: load-use-drop each component
1547        if self.base.load_strategy == LoadStrategy::Sequential {
1548            return self.generate_sequential(req);
1549        }
1550
1551        // Eager mode: if the requested LoRA stack differs from what's merged
1552        // into the loaded UNet, drop it now so `reload_unet_if_needed` rebuilds
1553        // it with the new wrapper.
1554        let requested_stack = lora_stack_fingerprint(&self.pending_loras);
1555        if requested_stack != self.active_lora_fingerprint {
1556            if let Some(loaded) = self.base.loaded.as_mut() {
1557                if loaded.unet.is_some() {
1558                    loaded.unet = None;
1559                    loaded.device.synchronize()?;
1560                    tracing::info!("SDXL UNet dropped (LoRA stack changed)");
1561                }
1562            }
1563            self.active_lora_fingerprint = requested_stack;
1564        }
1565
1566        // Eager mode: reload UNet if dropped after previous VAE decode
1567        self.reload_unet_if_needed()?;
1568
1569        let loaded = self
1570            .base
1571            .loaded
1572            .as_ref()
1573            .ok_or_else(|| anyhow::anyhow!("model not loaded — call load() first"))?;
1574
1575        let start = Instant::now();
1576        let seed = req.seed.unwrap_or_else(rand_seed);
1577
1578        let width = req.width as usize;
1579        let height = req.height as usize;
1580        let guidance = req.guidance;
1581
1582        tracing::info!(
1583            prompt = %req.prompt,
1584            seed, width, height,
1585            steps = req.steps,
1586            guidance,
1587            scheduler = %self.scheduler,
1588            "starting SDXL generation"
1589        );
1590
1591        // 1. Encode prompt with both CLIP encoders
1592        let max_len = loaded.sd_config.clip.max_position_embeddings;
1593        let neg = req.negative_prompt.as_deref().unwrap_or("");
1594        let text_embeddings = self.encode_prompt(
1595            &loaded.clip_l,
1596            &loaded.clip_g,
1597            &loaded.tokenizer_l,
1598            &loaded.tokenizer_g,
1599            &req.prompt,
1600            neg,
1601            max_len,
1602            &loaded.device,
1603            &loaded.clip_device,
1604            loaded.dtype,
1605            guidance,
1606        )?;
1607
1608        // 3. Build scheduler and create initial latents
1609        let sched = req.scheduler.unwrap_or(self.scheduler);
1610
1611        let (mut latents, start_step, inpaint_ctx) =
1612            if let Some(ref source_bytes) = req.source_image {
1613                let (latents, start_step, encoded, noise) = self.prepare_img2img_latents(
1614                    &loaded.vae,
1615                    source_bytes,
1616                    req.width,
1617                    req.height,
1618                    req.strength,
1619                    req.steps,
1620                    sched,
1621                    seed,
1622                    &loaded.device,
1623                    loaded.dtype,
1624                    loaded.vae_dtype,
1625                )?;
1626                let inpaint_ctx = if let Some(ref mask_bytes) = req.mask_image {
1627                    let mask = self.cached_mask(
1628                        mask_bytes,
1629                        height / 8,
1630                        width / 8,
1631                        &loaded.device,
1632                        loaded.dtype,
1633                    )?;
1634                    Some(crate::img_utils::InpaintContext {
1635                        original_latents: encoded,
1636                        mask,
1637                        noise,
1638                    })
1639                } else {
1640                    None
1641                };
1642                (latents, start_step, inpaint_ctx)
1643            } else {
1644                let latent_h = height / 8;
1645                let latent_w = width / 8;
1646                let init_scheduler = crate::scheduler::build_scheduler(
1647                    sched,
1648                    req.steps as usize,
1649                    PredictionType::Epsilon,
1650                    self.is_turbo,
1651                )?;
1652                let init_noise_sigma = init_scheduler.init_noise_sigma();
1653                drop(init_scheduler);
1654                let latents = (crate::engine::seeded_randn(
1655                    seed,
1656                    &[1, 4, latent_h, latent_w],
1657                    &loaded.device,
1658                    DType::F32,
1659                )? * init_noise_sigma)?;
1660                (latents.to_dtype(loaded.dtype)?, 0, None)
1661            };
1662
1663        // 5. Denoising loop
1664        let unet = loaded
1665            .unet
1666            .as_ref()
1667            .ok_or_else(|| anyhow::anyhow!("UNet not loaded"))?;
1668        self.denoise_loop(
1669            unet,
1670            &text_embeddings,
1671            sched,
1672            &mut latents,
1673            guidance,
1674            resolve_cfg_plus(req),
1675            req.steps,
1676            start_step,
1677            inpaint_ctx.as_ref(),
1678        )?;
1679
1680        // Drop UNet before VAE decode to free VRAM for conv2d intermediates.
1681        drop(inpaint_ctx);
1682        let _ = loaded;
1683        let loaded = self.base.loaded.as_mut().unwrap();
1684        loaded.unet = None;
1685        loaded.device.synchronize()?;
1686        tracing::info!("UNet dropped to free VRAM for VAE decode");
1687        let _ = loaded;
1688        let loaded = self.base.loaded.as_ref().unwrap();
1689
1690        // 6. VAE decode
1691        self.base.progress.stage_start("VAE decode");
1692        let vae_start = Instant::now();
1693
1694        let vae_scale = if self.is_turbo {
1695            VAE_SCALE_TURBO
1696        } else {
1697            VAE_SCALE_STANDARD
1698        };
1699        let latents = (latents / vae_scale)?;
1700        let latents_for_vae = latents.to_dtype(loaded.vae_dtype)?;
1701        let vae = &loaded.vae;
1702        let device_for_sync = loaded.device.clone();
1703        let img = crate::vae_tiling::decode_with_oom_fallback(
1704            &latents_for_vae,
1705            |t| vae.decode(t).map_err(Into::into),
1706            || {
1707                if let Err(e) = device_for_sync.synchronize() {
1708                    tracing::warn!(
1709                        "SDXL (parallel) device.synchronize() after VAE OOM failed: {e}"
1710                    );
1711                }
1712            },
1713        )?;
1714
1715        // 7. Post-process: [1, 3, H, W] → clamp → u8
1716        let img = ((img / 2.)? + 0.5)?.clamp(0f32, 1f32)?;
1717        let img = (img * 255.)?.to_dtype(DType::U8)?;
1718        let img = img.squeeze(0)?; // [3, H, W]
1719
1720        self.base
1721            .progress
1722            .stage_done("VAE decode", vae_start.elapsed());
1723
1724        // 8. Encode to image format
1725        let output_metadata = build_output_metadata(req, seed, Some(sched));
1726        let image_bytes = encode_image(
1727            &img,
1728            req.resolved_output_format(),
1729            req.width,
1730            req.height,
1731            output_metadata.as_ref(),
1732        )?;
1733
1734        let generation_time_ms = start.elapsed().as_millis() as u64;
1735        tracing::info!(generation_time_ms, seed, "SDXL generation complete");
1736
1737        Ok(GenerateResponse {
1738            images: vec![ImageData {
1739                data: image_bytes,
1740                format: req.resolved_output_format(),
1741                width: req.width,
1742                height: req.height,
1743                index: 0,
1744            }],
1745            generation_time_ms,
1746            model: req.model.clone(),
1747            seed_used: seed,
1748            video: None,
1749            gpu: None,
1750        })
1751    }
1752}
1753
1754impl InferenceEngine for SDXLEngine {
1755    fn generate(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1756        self.pending_placement = req.placement.clone();
1757        self.pending_loras = super::lora::effective_sdxl_loras(req);
1758        let result = self.generate_inner(req);
1759        self.pending_placement = None;
1760        self.pending_loras.clear();
1761        result
1762    }
1763
1764    fn model_name(&self) -> &str {
1765        self.base.model_name()
1766    }
1767
1768    fn is_loaded(&self) -> bool {
1769        // Sequential mode is always "ready" — it loads on demand
1770        self.base.is_loaded()
1771    }
1772
1773    fn load(&mut self) -> Result<()> {
1774        SDXLEngine::load(self)
1775    }
1776
1777    fn unload(&mut self) {
1778        self.base.unload();
1779        clear_cache(&self.prompt_cache);
1780        clear_cache(&self.source_latent_cache);
1781        clear_cache(&self.mask_cache);
1782        self.active_lora_fingerprint.clear();
1783    }
1784
1785    fn set_on_progress(&mut self, callback: ProgressCallback) {
1786        self.base.set_on_progress(callback);
1787    }
1788
1789    fn clear_on_progress(&mut self) {
1790        self.base.clear_on_progress();
1791    }
1792
1793    fn model_paths(&self) -> Option<&mold_core::ModelPaths> {
1794        Some(&self.base.paths)
1795    }
1796}
1797
1798#[cfg(test)]
1799mod tests {
1800    use super::*;
1801    use crate::engine::InferenceEngine;
1802    use crate::shared_pool::SharedPool;
1803    use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
1804    use std::collections::HashMap;
1805    use std::sync::{Arc, Mutex};
1806    use tokenizers::models::bpe::BPE;
1807
1808    /// Synthesise a minimal SDXL-shaped single-file safetensors with one
1809    /// representative key per component bucket. Tensor data is one zero
1810    /// F32 — `loader::single_file::load` is header-only and the
1811    /// constructor doesn't materialise weights.
1812    fn synth_sdxl_single_file(name: &str) -> PathBuf {
1813        let path = std::env::temp_dir().join(format!(
1814            "mold-sdxl-from-sf-{}-{}-{}.safetensors",
1815            name,
1816            std::process::id(),
1817            std::time::SystemTime::now()
1818                .duration_since(std::time::UNIX_EPOCH)
1819                .unwrap()
1820                .as_nanos(),
1821        ));
1822
1823        let keys: &[&str] = &[
1824            // UNet
1825            "model.diffusion_model.input_blocks.0.0.weight",
1826            "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight",
1827            // VAE
1828            "first_stage_model.encoder.down.0.block.0.norm1.weight",
1829            "first_stage_model.quant_conv.weight",
1830            // CLIP-L
1831            "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight",
1832            "conditioner.embedders.0.transformer.text_model.final_layer_norm.weight",
1833            // CLIP-G — including the fused QKV slab so the constructor
1834            // exercises the FusedSlice branch in build_sdxl_remap.
1835            "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight",
1836            "conditioner.embedders.1.model.text_projection",
1837        ];
1838
1839        let f32_zero = 0.0f32.to_le_bytes().to_vec();
1840        let buffers: Vec<Vec<u8>> = keys.iter().map(|_| f32_zero.clone()).collect();
1841        let mut tensors: HashMap<String, TensorView<'_>> = HashMap::new();
1842        for (key, buf) in keys.iter().zip(buffers.iter()) {
1843            tensors.insert(
1844                (*key).to_string(),
1845                TensorView::new(SafeDtype::F32, vec![1], buf).unwrap(),
1846            );
1847        }
1848        serialize_to_file(&tensors, &None, &path).unwrap();
1849        path
1850    }
1851
1852    #[test]
1853    fn from_single_file_constructs_for_synthetic_sdxl_checkpoint() {
1854        let single_file = synth_sdxl_single_file("ok");
1855        // Companion CLIP-L / CLIP-G tokenizers are pulled separately
1856        // (phase 2.7) and are not validated by the constructor — their
1857        // paths are just stored.
1858        let clip_l_tok = std::env::temp_dir().join("mold-sdxl-clip-l-stub.json");
1859        let clip_g_tok = std::env::temp_dir().join("mold-sdxl-clip-g-stub.json");
1860
1861        let engine = SDXLEngine::from_single_file(
1862            "juggernaut-xl-v9".to_string(),
1863            single_file.clone(),
1864            clip_l_tok,
1865            clip_g_tok,
1866            Scheduler::default(),
1867            false,
1868            LoadStrategy::Eager,
1869            0,
1870            None,
1871        )
1872        .expect("constructor must accept a valid SDXL single-file layout");
1873
1874        assert_eq!(engine.model_name(), "juggernaut-xl-v9");
1875        assert_eq!(
1876            engine.single_file_path.as_deref(),
1877            Some(single_file.as_path()),
1878            "single-file path must be stashed for the future load() branch",
1879        );
1880        assert!(
1881            !engine.is_loaded(),
1882            "constructor must not eagerly materialise model weights",
1883        );
1884
1885        let _ = std::fs::remove_file(single_file);
1886    }
1887
1888    #[test]
1889    fn sdxl_loads_clip_tokenizers_through_shared_pool() {
1890        let dir = tempfile::tempdir().unwrap();
1891        let clip_l_tokenizer = dir.path().join("clip-l-tokenizer.json");
1892        let clip_g_tokenizer = dir.path().join("clip-g-tokenizer.json");
1893        tokenizers::Tokenizer::new(BPE::default())
1894            .save(&clip_l_tokenizer, false)
1895            .unwrap();
1896        tokenizers::Tokenizer::new(BPE::default())
1897            .save(&clip_g_tokenizer, false)
1898            .unwrap();
1899        let weights_path = dir.path().join("weights.safetensors");
1900        std::fs::write(&weights_path, b"stub").unwrap();
1901
1902        let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
1903        let pooled_l = shared_pool
1904            .lock()
1905            .unwrap()
1906            .load_tokenizer(&clip_l_tokenizer)
1907            .unwrap();
1908        let pooled_g = shared_pool
1909            .lock()
1910            .unwrap()
1911            .load_tokenizer(&clip_g_tokenizer)
1912            .unwrap();
1913
1914        let paths = ModelPaths {
1915            transformer: weights_path.clone(),
1916            transformer_shards: Vec::new(),
1917            vae: weights_path.clone(),
1918            spatial_upscaler: None,
1919            temporal_upscaler: None,
1920            distilled_lora: None,
1921            t5_encoder: None,
1922            clip_encoder: Some(weights_path.clone()),
1923            t5_tokenizer: None,
1924            clip_tokenizer: Some(clip_l_tokenizer.clone()),
1925            clip_encoder_2: Some(weights_path),
1926            clip_tokenizer_2: Some(clip_g_tokenizer.clone()),
1927            text_encoder_files: Vec::new(),
1928            text_tokenizer: None,
1929            decoder: None,
1930        };
1931        let engine = SDXLEngine::new(
1932            "sdxl-test".to_string(),
1933            paths,
1934            Scheduler::default(),
1935            false,
1936            LoadStrategy::Eager,
1937            0,
1938            Some(shared_pool),
1939        );
1940
1941        let loaded_l = engine
1942            .load_clip_tokenizer(&clip_l_tokenizer, "CLIP-L")
1943            .unwrap();
1944        let loaded_g = engine
1945            .load_clip_tokenizer(&clip_g_tokenizer, "CLIP-G")
1946            .unwrap();
1947
1948        assert!(Arc::ptr_eq(&pooled_l, &loaded_l));
1949        assert!(Arc::ptr_eq(&pooled_g, &loaded_g));
1950    }
1951
1952    #[test]
1953    fn sdxl_loads_vae_tensors_through_shared_pool() {
1954        let dir = tempfile::tempdir().unwrap();
1955        let vae_path = dir.path().join("vae.safetensors");
1956        let weight = 1.0f32.to_le_bytes();
1957        let mut tensors = HashMap::new();
1958        tensors.insert(
1959            "encoder.conv_in.weight".to_string(),
1960            TensorView::new(SafeDtype::F32, vec![1], &weight).unwrap(),
1961        );
1962        serialize_to_file(&tensors, &None, &vae_path).unwrap();
1963
1964        let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
1965        let pooled = shared_pool
1966            .lock()
1967            .unwrap()
1968            .load_safetensors_cpu_tensors(std::slice::from_ref(&vae_path))
1969            .unwrap()
1970            .unwrap();
1971
1972        let paths = ModelPaths {
1973            transformer: dir.path().join("unet.safetensors"),
1974            transformer_shards: Vec::new(),
1975            vae: vae_path.clone(),
1976            spatial_upscaler: None,
1977            temporal_upscaler: None,
1978            distilled_lora: None,
1979            t5_encoder: None,
1980            clip_encoder: Some(dir.path().join("clip-l.safetensors")),
1981            t5_tokenizer: None,
1982            clip_tokenizer: Some(dir.path().join("clip-l-tokenizer.json")),
1983            clip_encoder_2: Some(dir.path().join("clip-g.safetensors")),
1984            clip_tokenizer_2: Some(dir.path().join("clip-g-tokenizer.json")),
1985            text_encoder_files: Vec::new(),
1986            text_tokenizer: None,
1987            decoder: None,
1988        };
1989        let engine = SDXLEngine::new(
1990            "sdxl-test".to_string(),
1991            paths,
1992            Scheduler::default(),
1993            false,
1994            LoadStrategy::Eager,
1995            0,
1996            Some(shared_pool),
1997        );
1998
1999        let loaded = engine.load_vae_cpu_tensors().unwrap().unwrap();
2000
2001        assert!(Arc::ptr_eq(&pooled, &loaded));
2002    }
2003
2004    #[test]
2005    fn from_single_file_rejects_missing_file() {
2006        let bogus = std::env::temp_dir().join(format!(
2007            "mold-sdxl-from-sf-missing-{}-{}.safetensors",
2008            std::process::id(),
2009            std::time::SystemTime::now()
2010                .duration_since(std::time::UNIX_EPOCH)
2011                .unwrap()
2012                .as_nanos(),
2013        ));
2014
2015        let result = SDXLEngine::from_single_file(
2016            "missing".to_string(),
2017            bogus,
2018            std::env::temp_dir().join("mold-sdxl-clip-l-stub.json"),
2019            std::env::temp_dir().join("mold-sdxl-clip-g-stub.json"),
2020            Scheduler::default(),
2021            false,
2022            LoadStrategy::Eager,
2023            0,
2024            None,
2025        );
2026
2027        assert!(
2028            result.is_err(),
2029            "constructor must surface a missing-file error before deeper parsing",
2030        );
2031    }
2032
2033    #[test]
2034    fn load_branches_to_single_file_path_and_invokes_candle_constructors() {
2035        // Phase 2.8.5 smoke (TDD parity with SD15): single-file SDXL
2036        // `load()` must dispatch to the single-file branch and hand four
2037        // `VarBuilder::from_backend(SingleFileBackend)` instances to
2038        // candle's per-component constructors. The synthetic checkpoint
2039        // doesn't carry the full SDXL tensor set (UNet, VAE, CLIP-L,
2040        // CLIP-G with fused QKV), so the load surfaces an error sourced
2041        // from the single-file layer rather than from the diffusers
2042        // fallback. Real-shape construction is exercised by 2.10's
2043        // <gpu-host> UAT against a real Pony / Juggernaut XL checkpoint.
2044        let single_file = synth_sdxl_single_file("load-branch");
2045        let make_stub = |label: &str| -> PathBuf {
2046            let path = std::env::temp_dir().join(format!(
2047                "mold-sdxl-{}-stub-{}-{}.json",
2048                label,
2049                std::process::id(),
2050                std::time::SystemTime::now()
2051                    .duration_since(std::time::UNIX_EPOCH)
2052                    .unwrap()
2053                    .as_nanos(),
2054            ));
2055            std::fs::write(&path, b"").unwrap();
2056            path
2057        };
2058        let clip_l_tok = make_stub("clip-l");
2059        let clip_g_tok = make_stub("clip-g");
2060
2061        let mut engine = SDXLEngine::from_single_file(
2062            "juggernaut-xl-v9".to_string(),
2063            single_file.clone(),
2064            clip_l_tok.clone(),
2065            clip_g_tok.clone(),
2066            Scheduler::Ddim,
2067            false,
2068            LoadStrategy::Eager,
2069            0,
2070            None,
2071        )
2072        .expect("constructor");
2073
2074        std::env::set_var("MOLD_DEVICE", "cpu");
2075        let err = SDXLEngine::load(&mut engine)
2076            .expect_err("synthetic checkpoint can't satisfy SDXL's full tensor set");
2077        std::env::remove_var("MOLD_DEVICE");
2078
2079        let msg = err.to_string();
2080        assert!(
2081            msg.contains("single-file") || msg.contains("rename rule"),
2082            "expected error from the single-file load layer, got: {msg}",
2083        );
2084
2085        let _ = std::fs::remove_file(single_file);
2086        let _ = std::fs::remove_file(clip_l_tok);
2087        let _ = std::fs::remove_file(clip_g_tok);
2088    }
2089
2090    /// Real-shape SDXL `load()` smoke — see SD15 sibling for the rationale
2091    /// for `#[ignore]`. 2.10 UAT replaces this with a real Pony / Juggernaut
2092    /// XL pull + load + 4-step generation.
2093    #[test]
2094    #[ignore]
2095    fn from_single_file_real_shape_load_smoke() {
2096        // Implementation deferred to the candle-fork-accessor follow-up.
2097    }
2098
2099    /// Sequential single-file dispatch must hand CLIP-L construction to
2100    /// `SingleFileBackend` (which translates diffusers `text_model.X` keys
2101    /// through `SdxlRemap` into A1111 reads), NOT to candle's diffusers-keyed
2102    /// `build_clip_transformer` (which calls `from_mmaped_safetensors` and
2103    /// errors with "cannot find tensor text_model.embeddings.token_embedding.weight"
2104    /// against an A1111-named Civitai checkpoint). The synthetic checkpoint
2105    /// has only one CLIP-L key, so construction *will* fail — but the
2106    /// failure must surface from our backend (no rename rule for the
2107    /// missing diffusers keys), not from the diffusers loader.
2108    ///
2109    /// Locks bug 1 from `tasks/catalog-run-bridge-option-c-handoff.md`:
2110    /// `mold run --local cv:1759168 "<prompt>"` (default sequential
2111    /// strategy) was bailing with the diffusers-loader error before this
2112    /// fix because `generate_sequential` ignored `single_file_path` and
2113    /// passed the A1111-named single-file path straight to
2114    /// `stable_diffusion::build_clip_transformer`.
2115    #[test]
2116    fn build_clip_l_single_file_dispatches_through_backend_not_diffusers_loader() {
2117        let single_file = synth_sdxl_single_file("seq-clip-l-dispatch");
2118        let remap = SDXLEngine::load_sdxl_remap(&single_file).expect("remap");
2119
2120        let result = SDXLEngine::build_clip_l_single_file(
2121            &single_file,
2122            &remap,
2123            &stable_diffusion::clip::Config::sdxl(),
2124            &Device::Cpu,
2125        );
2126
2127        let err = result.expect_err(
2128            "synthetic CLIP-L is missing token_embedding / position_embedding / \
2129             every encoder layer beyond layer 0 — construction must fail",
2130        );
2131        let msg = err.to_string();
2132        assert!(
2133            !msg.contains("cannot find tensor text_model"),
2134            "expected failure from the SingleFileBackend layer (e.g. 'no rename rule \
2135             for diffusers key text_model.embeddings.token_embedding.weight'); got the \
2136             diffusers `from_mmaped_safetensors` error instead — sequential dispatch \
2137             is still routing through `build_clip_transformer`. Got: {msg}",
2138        );
2139
2140        let _ = std::fs::remove_file(single_file);
2141    }
2142
2143    /// Same dispatch contract as the CLIP-L test, but for CLIP-G — covers
2144    /// the `clip2_config` (1280-dim sdxl2) path.
2145    #[test]
2146    fn build_clip_g_single_file_dispatches_through_backend_not_diffusers_loader() {
2147        let single_file = synth_sdxl_single_file("seq-clip-g-dispatch");
2148        let remap = SDXLEngine::load_sdxl_remap(&single_file).expect("remap");
2149
2150        let result = SDXLEngine::build_clip_g_single_file(
2151            &single_file,
2152            &remap,
2153            &stable_diffusion::clip::Config::sdxl2(),
2154            &Device::Cpu,
2155        );
2156
2157        let err = result.expect_err("synthetic CLIP-G is incomplete");
2158        let msg = err.to_string();
2159        assert!(
2160            !msg.contains("cannot find tensor text_model"),
2161            "expected failure from SingleFileBackend, not diffusers loader. Got: {msg}",
2162        );
2163
2164        let _ = std::fs::remove_file(single_file);
2165    }
2166
2167    /// UNet sequential dispatch parity — same shape as the CLIP tests but
2168    /// for `build_unet_single_file`. Sequential mode loads UNet between
2169    /// the CLIP encoders (drop) and VAE (load); without single-file dispatch
2170    /// it would bail with "cannot find tensor conv_in.weight" or similar.
2171    #[test]
2172    fn build_unet_single_file_dispatches_through_backend_not_diffusers_loader() {
2173        let single_file = synth_sdxl_single_file("seq-unet-dispatch");
2174        let remap = SDXLEngine::load_sdxl_remap(&single_file).expect("remap");
2175
2176        let result = SDXLEngine::build_unet_single_file(
2177            &single_file,
2178            &remap,
2179            &stable_diffusion::StableDiffusionConfig::sdxl(None, None, None),
2180            &Device::Cpu,
2181            DType::F32,
2182        );
2183
2184        let err = result.expect_err("synthetic UNet is incomplete");
2185        let msg = err.to_string();
2186        assert!(
2187            !msg.contains("cannot find tensor conv_in"),
2188            "expected failure from SingleFileBackend, not diffusers loader. Got: {msg}",
2189        );
2190
2191        let _ = std::fs::remove_file(single_file);
2192    }
2193
2194    // ----- Phase 2.6: is_turbo threaded into the single-file constructor -----
2195
2196    #[test]
2197    fn from_single_file_threads_is_turbo_true() {
2198        let single_file = synth_sdxl_single_file("turbo");
2199        let clip_l_tok = std::env::temp_dir().join("mold-sdxl-turbo-clip-l-stub.json");
2200        let clip_g_tok = std::env::temp_dir().join("mold-sdxl-turbo-clip-g-stub.json");
2201
2202        let engine = SDXLEngine::from_single_file(
2203            "sdxl-turbo:fp16".to_string(),
2204            single_file.clone(),
2205            clip_l_tok,
2206            clip_g_tok,
2207            Scheduler::EulerAncestral,
2208            true,
2209            LoadStrategy::Eager,
2210            0,
2211            None,
2212        )
2213        .expect("constructor must accept is_turbo = true");
2214
2215        assert!(
2216            engine.is_turbo,
2217            "is_turbo arg must thread into the engine field — sdxl_config() reads this for VAE_SCALE_TURBO",
2218        );
2219
2220        let _ = std::fs::remove_file(single_file);
2221    }
2222
2223    #[test]
2224    fn from_single_file_threads_is_turbo_false() {
2225        let single_file = synth_sdxl_single_file("standard");
2226        let clip_l_tok = std::env::temp_dir().join("mold-sdxl-std-clip-l-stub.json");
2227        let clip_g_tok = std::env::temp_dir().join("mold-sdxl-std-clip-g-stub.json");
2228
2229        let engine = SDXLEngine::from_single_file(
2230            "sdxl-base:fp16".to_string(),
2231            single_file.clone(),
2232            clip_l_tok,
2233            clip_g_tok,
2234            Scheduler::Ddim,
2235            false,
2236            LoadStrategy::Eager,
2237            0,
2238            None,
2239        )
2240        .expect("constructor must accept is_turbo = false");
2241
2242        assert!(
2243            !engine.is_turbo,
2244            "is_turbo = false must produce a standard-config engine",
2245        );
2246
2247        let _ = std::fs::remove_file(single_file);
2248    }
2249
2250    #[test]
2251    fn single_file_sdxl_vae_defaults_to_f32_to_avoid_black_finetune_decodes() {
2252        unsafe { std::env::remove_var("MOLD_VAE_DTYPE") };
2253        assert_eq!(resolve_sdxl_vae_dtype(DType::F16, true), DType::F32);
2254        assert_eq!(resolve_sdxl_vae_dtype(DType::F16, false), DType::F16);
2255    }
2256
2257    // SDXL cfg_active predicate tests — pin the cfg=1.0 short-circuit so
2258    // a regression to `guidance > 1.0` is caught at the per-pipeline boundary.
2259
2260    #[test]
2261    fn test_cfg_disabled_at_guidance_1_0() {
2262        assert!(!cfg_active(1.0));
2263    }
2264
2265    #[test]
2266    fn test_cfg_disabled_just_below_1_0() {
2267        assert!(!cfg_active(1.0 - 1e-5));
2268    }
2269
2270    #[test]
2271    fn test_cfg_enabled_at_guidance_1_5() {
2272        assert!(cfg_active(1.5));
2273    }
2274
2275    #[test]
2276    fn test_cfg_enabled_at_guidance_7_5() {
2277        assert!(cfg_active(7.5));
2278    }
2279
2280    /// `lora_stack_fingerprint` must produce equal fingerprints for the same
2281    /// (path, scale) pairs and different ones for any change — driving the
2282    /// eager-mode UNet drop in `generate_inner`. This is a pure-function
2283    /// pin so a refactor can't silently drop the comparison check.
2284    #[test]
2285    fn lora_stack_fingerprint_equality_drives_unet_drop() {
2286        let a = mold_core::LoraWeight {
2287            path: "/x.safetensors".into(),
2288            scale: 0.8,
2289        };
2290        let b = mold_core::LoraWeight {
2291            path: "/y.safetensors".into(),
2292            scale: 0.4,
2293        };
2294        let same_a = mold_core::LoraWeight {
2295            path: "/x.safetensors".into(),
2296            scale: 0.8,
2297        };
2298        // Equal stacks → equal fingerprints.
2299        assert_eq!(
2300            lora_stack_fingerprint(&[a.clone(), b.clone()]),
2301            lora_stack_fingerprint(&[same_a.clone(), b.clone()])
2302        );
2303        // Same paths, different scale → different fingerprints.
2304        let scaled = mold_core::LoraWeight {
2305            path: "/x.safetensors".into(),
2306            scale: 0.9,
2307        };
2308        assert_ne!(
2309            lora_stack_fingerprint(std::slice::from_ref(&a)),
2310            lora_stack_fingerprint(std::slice::from_ref(&scaled))
2311        );
2312        // Reordered stack → different fingerprints (FLUX behaviour: the
2313        // merge order is observable through delta accumulation order even
2314        // when the sum is commutative, because alpha/scale normalisation
2315        // is applied per-layer per-spec).
2316        assert_ne!(
2317            lora_stack_fingerprint(&[a.clone(), b.clone()]),
2318            lora_stack_fingerprint(&[b, a])
2319        );
2320    }
2321
2322    /// Regression test for the SDXL prompt-cache key bug: keying only on
2323    /// the positive prompt + guidance (as the original code did) returns
2324    /// stale (uncond_old, cond) when the user changes just the negative.
2325    #[test]
2326    fn sdxl_prompt_cache_distinguishes_negative_prompt_changes() {
2327        use crate::cache::{cfg_prompt_cache_key, store_cached_tensor};
2328
2329        let cache: Mutex<LruCache<CfgPromptCacheKey, CachedTensor>> =
2330            Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY));
2331        let device = Device::Cpu;
2332        let dtype = DType::F32;
2333        let embeddings = candle_core::Tensor::zeros((1, 4), dtype, &device).unwrap();
2334
2335        let key_a = cfg_prompt_cache_key("a cat", "blurry", 7.0);
2336        store_cached_tensor(&cache, key_a.clone(), &embeddings).unwrap();
2337
2338        // Same positive + same guidance, different negative → MUST miss.
2339        let key_b = cfg_prompt_cache_key("a cat", "low quality", 7.0);
2340        let restored = restore_cached_tensor(&cache, &key_b, &device, dtype).unwrap();
2341        assert!(
2342            restored.is_none(),
2343            "different negative prompt must miss the cache (silent-wrong-output bug)"
2344        );
2345
2346        // Same key as the insert → MUST hit.
2347        let restored = restore_cached_tensor(&cache, &key_a, &device, dtype).unwrap();
2348        assert!(
2349            restored.is_some(),
2350            "identical (pos, neg, guidance) must hit"
2351        );
2352    }
2353}