Skip to main content

mold_inference/wuerstchen/
pipeline.rs

1use anyhow::{bail, Result};
2use candle_core::{DType, Device, Tensor};
3use candle_nn::VarBuilder;
4use candle_transformers::models::stable_diffusion;
5use candle_transformers::models::wuerstchen::ddpm::{DDPMWScheduler, DDPMWSchedulerConfig};
6use candle_transformers::models::wuerstchen::diffnext::WDiffNeXt;
7use candle_transformers::models::wuerstchen::paella_vq::PaellaVQ;
8use candle_transformers::models::wuerstchen::prior::WPrior;
9use mold_core::{GenerateRequest, GenerateResponse, ImageData, ModelPaths};
10use std::collections::HashMap;
11use std::path::Path;
12use std::sync::{Arc, Mutex};
13use std::time::Instant;
14use tokenizers::Tokenizer;
15
16use crate::cache::{
17    clear_cache, get_or_insert_cached_tensor_pair, restore_cached_tensor_pair, CachedTensorPair,
18    LruCache, DEFAULT_PROMPT_CACHE_CAPACITY,
19};
20use crate::device::{check_memory_budget, memory_status_string, preflight_memory_check};
21use crate::engine::{rand_seed, InferenceEngine, LoadStrategy};
22use crate::engine_base::EngineBase;
23use crate::image::{build_output_metadata, encode_image, update_output_metadata_size};
24use crate::img_utils;
25use crate::progress::{ProgressCallback, ProgressEvent};
26
// Stage C (Prior) architecture hyper-parameters for Wuerstchen v2, passed
// positionally to `WPrior::new`. PRIOR_C_COND matches the 1280-dim Prior CLIP-G.
const PRIOR_C_IN: usize = 16;
const PRIOR_C: usize = 1_536;
const PRIOR_C_COND: usize = 1_280;
const PRIOR_C_R: usize = 64;
const PRIOR_DEPTH: usize = 32;
const PRIOR_NHEAD: usize = 24;

// Stage B (Decoder) architecture hyper-parameters for Wuerstchen v2, passed
// positionally to `WDiffNeXt::new`. DECODER_CLIP_EMBD matches the 1024-dim
// Decoder CLIP.
const DECODER_C_IN: usize = 4;
const DECODER_C_OUT: usize = 4;
const DECODER_C_R: usize = 64;
const DECODER_C_COND: usize = 1_024;
const DECODER_CLIP_EMBD: usize = 1_024;
const DECODER_PATCH_SIZE: usize = 2;

/// Pixel-to-latent compression ratio of the Stage C (Prior) latent space
/// (Wuerstchen's roughly 42x compression).
const LATENT_DIM_SCALE: f64 = 42.67;

/// Multiplier that takes the Prior's output spatial size to the Decoder's
/// latent spatial size.
const LATENT_DIM_SCALE_DECODER: f64 = 10.67;

/// VQ-GAN latent scale. Decoder latents are multiplied by this before VQ-GAN
/// decoding; for img2img, VQ-GAN encoder output is divided by it to map back
/// into the decoder latent space.
const VQGAN_SCALE: f64 = 0.3764;
53
54/// Loaded Wuerstchen model components, ready for inference.
55struct LoadedWuerstchen {
56    /// None after being dropped for VQ-GAN decode VRAM; reloaded on next generate.
57    prior: Option<WPrior>,
58    /// None after being dropped for VQ-GAN decode VRAM; reloaded on next generate.
59    decoder: Option<WDiffNeXt>,
60    vqgan: PaellaVQ,
61    prior_clip: stable_diffusion::clip::ClipTextTransformer,
62    decoder_clip: stable_diffusion::clip::ClipTextTransformer,
63    prior_tokenizer: Arc<Tokenizer>,
64    decoder_tokenizer: Arc<Tokenizer>,
65    device: Device,
66    clip_device: Device,
67    dtype: DType,
68}
69
70/// Wuerstchen v2 inference engine.
71///
72/// Three-stage cascade: CLIP-G encode -> Prior (Stage C) -> Decoder (Stage B) -> VQ-GAN (Stage A).
73pub struct WuerstchenEngine {
74    base: EngineBase<LoadedWuerstchen>,
75    prompt_cache: Mutex<LruCache<String, CachedTensorPair>>,
76    pending_placement: Option<mold_core::types::DevicePlacement>,
77    shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
78}
79
80impl WuerstchenEngine {
81    fn debug_tensor_stats(name: &str, tensor: &Tensor) {
82        if std::env::var_os("MOLD_WUERSTCHEN_DEBUG").is_none() {
83            return;
84        }
85
86        let stats = || -> Result<String> {
87            let dims = tensor.dims().to_vec();
88            let dtype = tensor.dtype();
89            let flat = tensor
90                .to_device(&Device::Cpu)?
91                .to_dtype(DType::F32)?
92                .flatten_all()?
93                .to_vec1::<f32>()?;
94
95            if flat.is_empty() {
96                return Ok(format!(
97                    "[wuerstchen-debug] {name}: shape={dims:?} dtype={dtype:?} <empty>"
98                ));
99            }
100
101            let mut min = f32::INFINITY;
102            let mut max = f32::NEG_INFINITY;
103            let mut sum = 0.0f64;
104            let mut sum_sq = 0.0f64;
105            let mut nan_count = 0usize;
106            let mut inf_count = 0usize;
107            let mut finite_count = 0usize;
108
109            for &value in &flat {
110                if value.is_nan() {
111                    nan_count += 1;
112                    continue;
113                }
114                if value.is_infinite() {
115                    inf_count += 1;
116                    continue;
117                }
118                min = min.min(value);
119                max = max.max(value);
120                let value = value as f64;
121                sum += value;
122                sum_sq += value * value;
123                finite_count += 1;
124            }
125
126            let (mean, std) = if finite_count > 0 {
127                let mean = sum / finite_count as f64;
128                let variance = (sum_sq / finite_count as f64) - (mean * mean);
129                (mean, variance.max(0.0).sqrt())
130            } else {
131                (f64::NAN, f64::NAN)
132            };
133            let checksum16: f64 = flat.iter().take(16).map(|&v| v as f64).sum();
134
135            Ok(format!(
136                "[wuerstchen-debug] {name}: shape={dims:?} dtype={dtype:?} min={min:.4} max={max:.4} mean={mean:.4} std={std:.4} nan={nan_count} inf={inf_count} checksum16={checksum16:.4}"
137            ))
138        };
139
140        match stats() {
141            Ok(message) => eprintln!("{message}"),
142            Err(err) => eprintln!("[wuerstchen-debug] {name}: <failed: {err}>"),
143        }
144    }
145
146    fn prompt_cache_key(
147        prompt: &str,
148        negative_prompt: &str,
149        use_prior_cfg: bool,
150        use_decoder_cfg: bool,
151    ) -> String {
152        format!(
153            "{prompt}\u{1f}{negative_prompt}\u{1f}prior_cfg={use_prior_cfg}\u{1f}decoder_cfg={use_decoder_cfg}"
154        )
155    }
156
157    fn pad_token_id(
158        tokenizer: &tokenizers::Tokenizer,
159        clip_config: &stable_diffusion::clip::Config,
160    ) -> Result<u32> {
161        let vocab = tokenizer.get_vocab(true);
162        let token = clip_config.pad_with.as_deref().unwrap_or("<|endoftext|>");
163        vocab
164            .get(token)
165            .copied()
166            .ok_or_else(|| anyhow::anyhow!("tokenizer missing pad/eos token '{token}'"))
167    }
168
169    fn encode_clip_prompt(
170        clip: &stable_diffusion::clip::ClipTextTransformer,
171        tokenizer: &tokenizers::Tokenizer,
172        clip_config: &stable_diffusion::clip::Config,
173        prompt: &str,
174        device: &Device,
175        dtype: DType,
176    ) -> Result<Tensor> {
177        let (tokens, tokens_len) = Self::tokenize(
178            tokenizer,
179            prompt,
180            clip_config.max_position_embeddings,
181            clip_config,
182            device,
183        )?;
184        Ok(clip
185            .forward_with_mask(&tokens, tokens_len - 1)?
186            .to_dtype(dtype)?)
187    }
188
189    fn decoder_guidance() -> f64 {
190        std::env::var("MOLD_WUERSTCHEN_DECODER_GUIDANCE")
191            .ok()
192            .and_then(|value| value.parse::<f64>().ok())
193            .unwrap_or(0.0)
194    }
195
196    fn effective_prior_steps(requested_steps: usize) -> usize {
197        if requested_steps < 10 {
198            // Very low step counts produce noise; warn but respect the request.
199            tracing::warn!(
200                steps = requested_steps,
201                "Wuerstchen prior works best with ≥20 steps"
202            );
203        }
204        requested_steps
205    }
206
207    pub fn new(
208        model_name: String,
209        paths: ModelPaths,
210        load_strategy: LoadStrategy,
211        gpu_ordinal: usize,
212        shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
213    ) -> Self {
214        Self {
215            base: EngineBase::new(model_name, paths, load_strategy, gpu_ordinal),
216            prompt_cache: Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY)),
217            pending_placement: None,
218            shared_pool,
219        }
220    }
221
222    fn load_clip_tokenizer(&self, path: &Path, label: &str) -> Result<Arc<Tokenizer>> {
223        if let Some(shared_pool) = &self.shared_pool {
224            return shared_pool.lock().unwrap().load_tokenizer(path);
225        }
226        Tokenizer::from_file(path)
227            .map(Arc::new)
228            .map_err(|e| anyhow::anyhow!("failed to load {label} tokenizer: {e}"))
229    }
230
231    fn load_vqgan_cpu_tensors(&self) -> Result<Option<Arc<HashMap<String, Tensor>>>> {
232        let Some(shared_pool) = &self.shared_pool else {
233            return Ok(None);
234        };
235        shared_pool
236            .lock()
237            .unwrap()
238            .load_safetensors_cpu_tensors(std::slice::from_ref(&self.base.paths.vae))
239    }
240
241    fn load_vqgan_var_builder<'a>(
242        &self,
243        dtype: DType,
244        device: &Device,
245        component: &str,
246    ) -> Result<VarBuilder<'a>> {
247        if let Some(tensors) = self.load_vqgan_cpu_tensors()? {
248            return Ok(crate::encoders::park::varbuilder_from_parked(
249                tensors.as_ref(),
250                dtype,
251                device,
252            ));
253        }
254
255        crate::weight_loader::load_safetensors_with_progress(
256            std::slice::from_ref(&self.base.paths.vae),
257            dtype,
258            device,
259            component,
260            &self.base.progress,
261        )
262    }
263
264    #[allow(clippy::too_many_arguments)]
265    fn encode_prompt_pair_cached(
266        &self,
267        prior_clip: &stable_diffusion::clip::ClipTextTransformer,
268        prior_tokenizer: &tokenizers::Tokenizer,
269        decoder_clip: &stable_diffusion::clip::ClipTextTransformer,
270        decoder_tokenizer: &tokenizers::Tokenizer,
271        prompt: &str,
272        negative_prompt: &str,
273        device: &Device,
274        clip_device: &Device,
275        dtype: DType,
276        prior_guidance: f64,
277        decoder_guidance: f64,
278    ) -> Result<(Tensor, Tensor)> {
279        let use_prior_cfg = prior_guidance > 1.0;
280        let use_decoder_cfg = decoder_guidance > 1.0;
281        let prior_clip_config = stable_diffusion::clip::Config::wuerstchen_prior();
282        let dec_clip_config = stable_diffusion::clip::Config::wuerstchen();
283        let cache_key =
284            Self::prompt_cache_key(prompt, negative_prompt, use_prior_cfg, use_decoder_cfg);
285        let ((prior_text_embeddings, decoder_text_embeddings), cache_hit) =
286            get_or_insert_cached_tensor_pair(&self.prompt_cache, cache_key, device, dtype, || {
287                self.base
288                    .progress
289                    .stage_start("Encoding prompt (Prior CLIP-G, 1280-dim)");
290                let encode_start = Instant::now();
291                let prior_text_embeddings = Self::encode_clip_prompt(
292                    prior_clip,
293                    prior_tokenizer,
294                    &prior_clip_config,
295                    prompt,
296                    clip_device,
297                    dtype,
298                )?;
299                let prior_text_embeddings = if use_prior_cfg {
300                    let prior_negative_embeddings = Self::encode_clip_prompt(
301                        prior_clip,
302                        prior_tokenizer,
303                        &prior_clip_config,
304                        negative_prompt,
305                        clip_device,
306                        dtype,
307                    )?;
308                    Tensor::cat(&[&prior_text_embeddings, &prior_negative_embeddings], 0)?
309                } else {
310                    prior_text_embeddings
311                };
312                self.base.progress.stage_done(
313                    "Encoding prompt (Prior CLIP-G, 1280-dim)",
314                    encode_start.elapsed(),
315                );
316
317                self.base
318                    .progress
319                    .stage_start("Encoding prompt (Decoder CLIP, 1024-dim)");
320                let dec_encode_start = Instant::now();
321                let decoder_text_embeddings = Self::encode_clip_prompt(
322                    decoder_clip,
323                    decoder_tokenizer,
324                    &dec_clip_config,
325                    prompt,
326                    clip_device,
327                    dtype,
328                )?;
329                let decoder_text_embeddings = if use_decoder_cfg {
330                    let decoder_negative_embeddings = Self::encode_clip_prompt(
331                        decoder_clip,
332                        decoder_tokenizer,
333                        &dec_clip_config,
334                        negative_prompt,
335                        clip_device,
336                        dtype,
337                    )?;
338                    Tensor::cat(&[&decoder_text_embeddings, &decoder_negative_embeddings], 0)?
339                } else {
340                    decoder_text_embeddings
341                };
342                self.base.progress.stage_done(
343                    "Encoding prompt (Decoder CLIP, 1024-dim)",
344                    dec_encode_start.elapsed(),
345                );
346                let prior_text_embeddings = prior_text_embeddings.to_device(device)?;
347                let decoder_text_embeddings = decoder_text_embeddings.to_device(device)?;
348                Ok((prior_text_embeddings, decoder_text_embeddings))
349            })?;
350        if cache_hit {
351            self.base.progress.cache_hit("prompt conditioning");
352        }
353        Ok((prior_text_embeddings, decoder_text_embeddings))
354    }
355
356    /// Validate and return required Wuerstchen paths.
357    /// Returns (decoder_path, prior_clip_encoder, prior_clip_tokenizer, decoder_clip_encoder, decoder_clip_tokenizer)
358    fn validate_paths(
359        &self,
360    ) -> Result<(
361        std::path::PathBuf,
362        std::path::PathBuf,
363        std::path::PathBuf,
364        std::path::PathBuf,
365        std::path::PathBuf,
366    )> {
367        let decoder = self
368            .base
369            .paths
370            .decoder
371            .as_ref()
372            .ok_or_else(|| anyhow::anyhow!("Decoder (Stage B) path required for Wuerstchen"))?
373            .clone();
374        // Prior CLIP-G (1280-dim) — stored in clip_encoder_2
375        let prior_clip_encoder = self
376            .base
377            .paths
378            .clip_encoder_2
379            .as_ref()
380            .ok_or_else(|| anyhow::anyhow!("Prior CLIP-G encoder path required for Wuerstchen"))?
381            .clone();
382        let prior_clip_tokenizer = self
383            .base
384            .paths
385            .clip_tokenizer_2
386            .as_ref()
387            .ok_or_else(|| anyhow::anyhow!("Prior CLIP-G tokenizer path required for Wuerstchen"))?
388            .clone();
389        // Decoder CLIP (1024-dim) — stored in clip_encoder.
390        // Fall back to Prior CLIP if decoder CLIP not available (old configs from
391        // before the dual-CLIP change). Quality will be degraded but won't crash.
392        let decoder_clip_encoder = self.base.paths.clip_encoder.clone().unwrap_or_else(|| {
393            tracing::warn!(
394                "Decoder CLIP encoder path not configured — falling back to Prior CLIP. \
395                     Run `mold rm wuerstchen-v2:fp16 && mold pull wuerstchen-v2:fp16` to fix."
396            );
397            prior_clip_encoder.clone()
398        });
399        let decoder_clip_tokenizer = self
400            .base
401            .paths
402            .clip_tokenizer
403            .clone()
404            .unwrap_or_else(|| prior_clip_tokenizer.clone());
405
406        for (label, path) in [
407            ("prior (Stage C)", &self.base.paths.transformer),
408            ("decoder (Stage B)", &decoder),
409            ("vqgan (Stage A)", &self.base.paths.vae),
410            ("prior clip_encoder", &prior_clip_encoder),
411            ("prior clip_tokenizer", &prior_clip_tokenizer),
412            ("decoder clip_encoder", &decoder_clip_encoder),
413            ("decoder clip_tokenizer", &decoder_clip_tokenizer),
414        ] {
415            if !path.exists() {
416                bail!("{label} file not found: {}", path.display());
417            }
418        }
419
420        Ok((
421            decoder,
422            prior_clip_encoder,
423            prior_clip_tokenizer,
424            decoder_clip_encoder,
425            decoder_clip_tokenizer,
426        ))
427    }
428
429    /// Reload Prior and Decoder if they were dropped after VQ-GAN decode.
430    fn reload_models_if_needed(&mut self) -> Result<()> {
431        let needs_reload = self
432            .base
433            .loaded
434            .as_ref()
435            .map(|l| l.prior.is_none() || l.decoder.is_none())
436            .unwrap_or(false);
437
438        if needs_reload {
439            let (decoder_path, _, _, _, _) = self.validate_paths()?;
440            let loaded = self.base.loaded.as_ref().unwrap();
441            let device = loaded.device.clone();
442            let dtype = loaded.dtype;
443            let _ = loaded;
444
445            self.base.progress.stage_start("Reloading Prior (Stage C)");
446            let reload_start = Instant::now();
447            let prior_vb = crate::weight_loader::load_safetensors_with_progress(
448                &[&self.base.paths.transformer],
449                dtype,
450                &device,
451                "Wuerstchen Prior",
452                &self.base.progress,
453            )?;
454            let prior = WPrior::new(
455                PRIOR_C_IN,
456                PRIOR_C,
457                PRIOR_C_COND,
458                PRIOR_C_R,
459                PRIOR_DEPTH,
460                PRIOR_NHEAD,
461                false,
462                prior_vb,
463            )?;
464            self.base
465                .progress
466                .stage_done("Reloading Prior (Stage C)", reload_start.elapsed());
467
468            self.base
469                .progress
470                .stage_start("Reloading Decoder (Stage B)");
471            let reload_start = Instant::now();
472            let decoder_vb = crate::weight_loader::load_safetensors_with_progress(
473                &[&decoder_path],
474                DType::F32,
475                &device,
476                "Wuerstchen Decoder",
477                &self.base.progress,
478            )?;
479            let decoder = WDiffNeXt::new(
480                DECODER_C_IN,
481                DECODER_C_OUT,
482                DECODER_C_R,
483                DECODER_C_COND,
484                DECODER_CLIP_EMBD,
485                DECODER_PATCH_SIZE,
486                false,
487                decoder_vb,
488            )?;
489            self.base
490                .progress
491                .stage_done("Reloading Decoder (Stage B)", reload_start.elapsed());
492
493            let loaded = self.base.loaded.as_mut().unwrap();
494            loaded.prior = Some(prior);
495            loaded.decoder = Some(decoder);
496        }
497        Ok(())
498    }
499
500    /// Load all Wuerstchen model components (Eager mode).
501    pub fn load(&mut self) -> Result<()> {
502        if self.base.loaded.is_some() {
503            return Ok(());
504        }
505
506        if self.base.load_strategy == LoadStrategy::Sequential {
507            return Ok(());
508        }
509
510        let (decoder_path, prior_clip_path, prior_clip_tok_path, dec_clip_path, dec_clip_tok_path) =
511            self.validate_paths()?;
512
513        tracing::info!(model = %self.base.model_name, "loading Wuerstchen model components...");
514
515        let device = crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)?;
516        // Use F16 on GPU for ~2x throughput and ~2x less VRAM.
517        // gen_r_embedding computes sincos basis in F32 internally, then casts to
518        // model dtype before the matmul — patched in candle-transformers-mold 0.9.4.
519        let dtype = if device.is_cpu() {
520            DType::F32
521        } else {
522            DType::F16
523        };
524
525        // Load Prior (Stage C)
526        self.base.progress.stage_start("Loading Prior (Stage C)");
527        let prior_start = Instant::now();
528        let prior_vb = crate::weight_loader::load_safetensors_with_progress(
529            &[&self.base.paths.transformer],
530            dtype,
531            &device,
532            "Wuerstchen Prior",
533            &self.base.progress,
534        )?;
535        let prior = WPrior::new(
536            PRIOR_C_IN,
537            PRIOR_C,
538            PRIOR_C_COND,
539            PRIOR_C_R,
540            PRIOR_DEPTH,
541            PRIOR_NHEAD,
542            false,
543            prior_vb,
544        )?;
545        self.base
546            .progress
547            .stage_done("Loading Prior (Stage C)", prior_start.elapsed());
548
549        // Load Decoder (Stage B) — F32 because the 256x256 latent space
550        // overflows F16 range during denoising (image_embeddings ±200).
551        self.base.progress.stage_start("Loading Decoder (Stage B)");
552        let decoder_start = Instant::now();
553        let decoder_vb = crate::weight_loader::load_safetensors_with_progress(
554            &[&decoder_path],
555            DType::F32,
556            &device,
557            "Wuerstchen Decoder",
558            &self.base.progress,
559        )?;
560        let decoder = WDiffNeXt::new(
561            DECODER_C_IN,
562            DECODER_C_OUT,
563            DECODER_C_R,
564            DECODER_C_COND,
565            DECODER_CLIP_EMBD,
566            DECODER_PATCH_SIZE,
567            false,
568            decoder_vb,
569        )?;
570        self.base
571            .progress
572            .stage_done("Loading Decoder (Stage B)", decoder_start.elapsed());
573
574        // Load VQ-GAN (Stage A) — always F32 for pixel-space decoding
575        self.base.progress.stage_start("Loading VQ-GAN (Stage A)");
576        let vqgan_start = Instant::now();
577        let vqgan_vb = self.load_vqgan_var_builder(DType::F32, &device, "VQ-GAN")?;
578        let vqgan = PaellaVQ::new(vqgan_vb)?;
579        self.base
580            .progress
581            .stage_done("Loading VQ-GAN (Stage A)", vqgan_start.elapsed());
582
583        // Tier 1: honor `placement.text_encoders` (both CLIPs as a group).
584        let tier1 = self
585            .pending_placement
586            .as_ref()
587            .map(|p| p.text_encoders)
588            .unwrap_or_default();
589        let clip_device = crate::device::resolve_device(Some(tier1), || Ok(device.clone()))?;
590
591        // Load Prior CLIP-G encoder (1280-dim, 32 layers)
592        self.base
593            .progress
594            .stage_start("Loading Prior CLIP-G encoder (1280-dim)");
595        let prior_clip_start = Instant::now();
596        let prior_clip_config = stable_diffusion::clip::Config::wuerstchen_prior();
597        let prior_clip = stable_diffusion::build_clip_transformer(
598            &prior_clip_config,
599            &prior_clip_path,
600            &clip_device,
601            DType::F32,
602        )?;
603        self.base.progress.stage_done(
604            "Loading Prior CLIP-G encoder (1280-dim)",
605            prior_clip_start.elapsed(),
606        );
607
608        // Load Decoder CLIP encoder (1024-dim, 24 layers)
609        self.base
610            .progress
611            .stage_start("Loading Decoder CLIP encoder (1024-dim)");
612        let dec_clip_start = Instant::now();
613        let dec_clip_config = stable_diffusion::clip::Config::wuerstchen();
614        let decoder_clip = stable_diffusion::build_clip_transformer(
615            &dec_clip_config,
616            &dec_clip_path,
617            &clip_device,
618            DType::F32,
619        )?;
620        self.base.progress.stage_done(
621            "Loading Decoder CLIP encoder (1024-dim)",
622            dec_clip_start.elapsed(),
623        );
624
625        // Load tokenizers
626        let prior_tokenizer = self.load_clip_tokenizer(&prior_clip_tok_path, "Prior CLIP-G")?;
627        let decoder_tokenizer = self.load_clip_tokenizer(&dec_clip_tok_path, "Decoder CLIP")?;
628
629        self.base.loaded = Some(LoadedWuerstchen {
630            prior: Some(prior),
631            decoder: Some(decoder),
632            vqgan,
633            prior_clip,
634            decoder_clip,
635            prior_tokenizer,
636            decoder_tokenizer,
637            device,
638            clip_device,
639            dtype,
640        });
641
642        tracing::info!(model = %self.base.model_name, "all Wuerstchen components loaded successfully");
643        Ok(())
644    }
645
646    /// Tokenize a prompt for a CLIP text encoder.
647    /// Returns (tokens_tensor, tokens_len) where tokens_len is the number of
648    /// real tokens before padding (used for forward_with_mask).
649    fn tokenize(
650        tokenizer: &tokenizers::Tokenizer,
651        prompt: &str,
652        max_len: usize,
653        clip_config: &stable_diffusion::clip::Config,
654        device: &Device,
655    ) -> Result<(Tensor, usize)> {
656        let pad_id = Self::pad_token_id(tokenizer, clip_config)?;
657        let encoding = tokenizer
658            .encode(prompt, true)
659            .map_err(|e| anyhow::anyhow!("tokenization failed: {e}"))?;
660        let mut ids = encoding.get_ids().to_vec();
661        ids.truncate(max_len);
662        if ids.is_empty() {
663            ids.push(pad_id);
664        }
665        let tokens_len = ids.len();
666        while ids.len() < max_len {
667            ids.push(pad_id);
668        }
669        Ok((
670            Tensor::new(ids.as_slice(), device)?.unsqueeze(0)?,
671            tokens_len,
672        ))
673    }
674
675    /// Run the Stage C (Prior) denoising loop.
676    #[allow(clippy::too_many_arguments)]
677    fn denoise_prior(
678        &self,
679        prior: &WPrior,
680        text_embeddings: &Tensor,
681        latents: &mut Tensor,
682        // TODO: use for per-step RNG reseeding to close RMSE gap vs candle reference
683        _base_seed: u64,
684        steps: usize,
685        guidance: f64,
686        device: &Device,
687        dtype: DType,
688    ) -> Result<()> {
689        let use_cfg = guidance > 1.0;
690        let scheduler = DDPMWScheduler::new(steps, DDPMWSchedulerConfig::default())?;
691        let timesteps = scheduler.timesteps().to_vec();
692
693        let label = format!("Stage C Prior ({} steps)", timesteps.len() - 1);
694        self.base.progress.stage_start(&label);
695        let start = Instant::now();
696
697        for (step_idx, &t) in timesteps.iter().enumerate() {
698            if step_idx + 1 >= timesteps.len() {
699                break; // last timestep is 0.0, not used for denoising
700            }
701            let step_start = Instant::now();
702
703            let noise_pred = if use_cfg {
704                // CFG: batch [latents, latents] with [text_embeddings, negative_prompt]
705                // text first (index 0), negative/unconditional second (index 1)
706                let latent_input = Tensor::cat(&[&*latents, &*latents], 0)?;
707                let r = (Tensor::ones(2, dtype, device)? * t)?;
708                let pred = prior.forward(&latent_input, &r, text_embeddings)?;
709                let chunks = pred.chunk(2, 0)?;
710                let (pred_text, pred_uncond) = (&chunks[0], &chunks[1]);
711                (pred_uncond + ((pred_text - pred_uncond)? * guidance)?)?
712            } else {
713                let r = (Tensor::ones(1, dtype, device)? * t)?;
714                prior.forward(&*latents, &r, text_embeddings)?
715            };
716
717            *latents = scheduler.step(&noise_pred, t, &*latents)?;
718
719            self.base.progress.emit(ProgressEvent::DenoiseStep {
720                step: step_idx + 1,
721                total: timesteps.len() - 1,
722                elapsed: step_start.elapsed(),
723            });
724        }
725
726        self.base.progress.stage_done(&label, start.elapsed());
727        Ok(())
728    }
729
730    /// Run the Stage B (Decoder) denoising loop.
731    ///
732    /// `image_embeddings` is the scaled Prior output (effnet slot in WDiffNeXt).
733    /// `text_embeddings` is the 1024-dim Decoder CLIP output (clip slot in WDiffNeXt).
734    /// Applies decoder CFG using Diffusers-style conditioning when guidance > 1.0.
735    ///
736    /// `start_step` allows starting from a later timestep for img2img (0 = full txt2img).
737    /// `inpaint_ctx` blends preserved regions back after each step for inpainting.
738    #[allow(clippy::too_many_arguments)]
739    fn denoise_decoder(
740        &self,
741        decoder: &WDiffNeXt,
742        image_embeddings: &Tensor,
743        text_embeddings: &Tensor,
744        latents: &mut Tensor,
745        // TODO: use for per-step RNG reseeding to close RMSE gap vs candle reference
746        _base_seed: u64,
747        steps: usize,
748        start_step: usize,
749        guidance: f64,
750        inpaint_ctx: Option<&img_utils::InpaintContext>,
751        device: &Device,
752        dtype: DType,
753    ) -> Result<()> {
754        let use_cfg = guidance > 1.0;
755        let scheduler = DDPMWScheduler::new(steps, DDPMWSchedulerConfig::default())?;
756        let timesteps = scheduler.timesteps().to_vec();
757        // Drop the final 0.0 timestep (not used for denoising), then skip start_step
758        let active_timesteps = &timesteps[start_step..timesteps.len() - 1];
759
760        let label = format!("Stage B Decoder ({} steps)", active_timesteps.len());
761        self.base.progress.stage_start(&label);
762        let start = Instant::now();
763
764        for (step_idx, &t) in active_timesteps.iter().enumerate() {
765            let step_start = Instant::now();
766
767            let noise_pred = if use_cfg {
768                let latent_input = Tensor::cat(&[&*latents, &*latents], 0)?;
769                let r = (Tensor::ones(2, dtype, device)? * t)?;
770                let effnet_input = Tensor::cat(
771                    &[image_embeddings, &Tensor::zeros_like(image_embeddings)?],
772                    0,
773                )?;
774                let pred =
775                    decoder.forward(&latent_input, &r, &effnet_input, Some(text_embeddings))?;
776                let chunks = pred.chunk(2, 0)?;
777                let (pred_text, pred_uncond) = (&chunks[0], &chunks[1]);
778                (pred_uncond + ((pred_text - pred_uncond)? * guidance)?)?
779            } else {
780                let r = (Tensor::ones(1, dtype, device)? * t)?;
781                decoder.forward(&*latents, &r, image_embeddings, Some(text_embeddings))?
782            };
783
784            *latents = scheduler.step(&noise_pred, t, &*latents)?;
785
786            // Inpainting: blend preserved regions back at current noise level
787            if let Some(ctx) = inpaint_ctx {
788                let noised_original = Self::ddpmw_add_noise(&ctx.original_latents, &ctx.noise, t)?;
789                *latents = crate::img2img::blend_inpaint_latents(&*latents, ctx, &noised_original)?;
790            }
791
792            self.base.progress.emit(ProgressEvent::DenoiseStep {
793                step: step_idx + 1,
794                total: active_timesteps.len(),
795                elapsed: step_start.elapsed(),
796            });
797        }
798
799        self.base.progress.stage_done(&label, start.elapsed());
800        Ok(())
801    }
802
803    /// DDPM noise addition for Wuerstchen's continuous timesteps.
804    ///
805    /// DDPMWScheduler doesn't expose `add_noise()`, so we implement the standard
806    /// DDPM forward process: `noised = sqrt(alpha_cumprod) * original + sqrt(1 - alpha_cumprod) * noise`
807    /// using the same cosine schedule as DDPMWScheduler.
808    fn ddpmw_add_noise(original: &Tensor, noise: &Tensor, t: f64) -> Result<Tensor> {
809        // Replicate DDPMWScheduler::alpha_cumprod with default config (scaler=1.0, s=0.008)
810        let s = 0.008f64;
811        let init_alpha_cumprod = (s / (1.0 + s) * std::f64::consts::PI).cos().powi(2);
812        let alpha_cumprod = ((t + s) / (1.0 + s) * std::f64::consts::PI * 0.5)
813            .cos()
814            .powi(2)
815            / init_alpha_cumprod;
816        let alpha_cumprod = alpha_cumprod.clamp(0.0001, 0.9999);
817
818        let sqrt_alpha = alpha_cumprod.sqrt();
819        let sqrt_one_minus_alpha = (1.0 - alpha_cumprod).sqrt();
820
821        let noised = ((original * sqrt_alpha)? + (noise * sqrt_one_minus_alpha)?)?;
822        Ok(noised)
823    }
824
825    /// Prepare img2img latents: VQ-GAN encode source image, add noise at the start timestep.
826    /// Returns (noised_latents, start_step, encoded_latents, noise).
827    #[allow(clippy::too_many_arguments)]
828    fn prepare_img2img_latents(
829        &self,
830        vqgan: &PaellaVQ,
831        source_bytes: &[u8],
832        width: u32,
833        height: u32,
834        strength: f64,
835        decoder_steps: usize,
836        seed: u64,
837        device: &Device,
838    ) -> Result<(Tensor, usize, Tensor, Tensor)> {
839        self.base
840            .progress
841            .stage_start("Encoding source image (VQ-GAN)");
842        let encode_start = Instant::now();
843
844        // VQ-GAN expects [0, 1] normalized input in F32
845        let source_tensor = img_utils::decode_source_image(
846            source_bytes,
847            width,
848            height,
849            img_utils::NormalizeRange::ZeroToOne,
850            device,
851            DType::F32,
852        )?;
853
854        let encoded = vqgan.encode(&source_tensor)?;
855        // Scale from VQ-GAN latent space to decoder latent space (inverse of decode scaling)
856        let encoded = (&encoded / VQGAN_SCALE)?;
857
858        self.base
859            .progress
860            .stage_done("Encoding source image (VQ-GAN)", encode_start.elapsed());
861
862        let start_step = crate::img2img::img2img_start_index(decoder_steps, strength);
863
864        // Generate deterministic noise matching decoder latent shape
865        let noise = crate::engine::seeded_randn(seed, encoded.dims(), device, DType::F32)?;
866
867        // Build scheduler to get timesteps for noise addition
868        let scheduler = DDPMWScheduler::new(decoder_steps, DDPMWSchedulerConfig::default())?;
869        let timesteps = scheduler.timesteps().to_vec();
870
871        // Add noise at the start timestep
872        let noised = if start_step < timesteps.len() - 1 {
873            Self::ddpmw_add_noise(&encoded, &noise, timesteps[start_step])?
874        } else {
875            encoded.clone()
876        };
877
878        tracing::info!(
879            start_step,
880            total_steps = decoder_steps,
881            strength,
882            "img2img: starting decoder from step {start_step}"
883        );
884
885        Ok((noised, start_step, encoded, noise))
886    }
887
888    /// Generate an image using sequential loading strategy.
889    fn generate_sequential(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
890        let (decoder_path, prior_clip_path, prior_clip_tok_path, dec_clip_path, dec_clip_tok_path) =
891            self.validate_paths()?;
892
893        if let Some(warning) = check_memory_budget(&self.base.paths, LoadStrategy::Sequential) {
894            self.base.progress.info(&warning);
895        }
896
897        let device = crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)?;
898        // Use F16 on GPU for ~2x throughput on the Prior stage.
899        // Decoder and VQ-GAN use F32 explicitly (see their load calls below).
900        let dtype = if device.is_cpu() {
901            DType::F32
902        } else {
903            DType::F16
904        };
905
906        let start = Instant::now();
907        let seed = req.seed.unwrap_or_else(rand_seed);
908        let width = req.width as usize;
909        let height = req.height as usize;
910        let prior_guidance = req.guidance;
911        let decoder_guidance = Self::decoder_guidance();
912        let negative_prompt = req.negative_prompt.as_deref().unwrap_or("");
913        let prior_steps = Self::effective_prior_steps(req.steps as usize);
914        let decoder_steps = 12;
915
916        tracing::info!(
917            prompt = %req.prompt,
918            seed, width, height,
919            prior_steps,
920            decoder_steps,
921            prior_guidance,
922            decoder_guidance,
923            "starting sequential Wuerstchen generation"
924        );
925
926        self.base
927            .progress
928            .info("Using sequential loading (load-use-drop) to minimize peak memory");
929
930        // --- Phase 1: Encode prompt (check cache first to skip encoder load) ---
931        let use_prior_cfg = prior_guidance > 1.0;
932        let use_decoder_cfg = decoder_guidance > 1.0;
933        let cache_key =
934            Self::prompt_cache_key(&req.prompt, negative_prompt, use_prior_cfg, use_decoder_cfg);
935        let (prior_text_embeddings, decoder_text_embeddings) =
936            if let Some((prior_emb, decoder_emb)) =
937                restore_cached_tensor_pair(&self.prompt_cache, &cache_key, &device, dtype)?
938            {
939                self.base.progress.cache_hit("prompt conditioning");
940                (prior_emb, decoder_emb)
941            } else {
942                if let Some(status) = memory_status_string() {
943                    self.base.progress.info(&status);
944                }
945
946                let prior_tokenizer =
947                    self.load_clip_tokenizer(&prior_clip_tok_path, "Prior CLIP-G")?;
948
949                let tier1 = self
950                    .pending_placement
951                    .as_ref()
952                    .map(|p| p.text_encoders)
953                    .unwrap_or_default();
954                let clip_device =
955                    crate::device::resolve_device(Some(tier1), || Ok(device.clone()))?;
956
957                self.base
958                    .progress
959                    .stage_start("Loading Prior CLIP-G encoder (1280-dim)");
960                let clip_start = Instant::now();
961                let prior_clip_config = stable_diffusion::clip::Config::wuerstchen_prior();
962                let prior_clip = stable_diffusion::build_clip_transformer(
963                    &prior_clip_config,
964                    &prior_clip_path,
965                    &clip_device,
966                    DType::F32,
967                )?;
968                self.base.progress.stage_done(
969                    "Loading Prior CLIP-G encoder (1280-dim)",
970                    clip_start.elapsed(),
971                );
972
973                let decoder_tokenizer =
974                    self.load_clip_tokenizer(&dec_clip_tok_path, "Decoder CLIP")?;
975
976                self.base
977                    .progress
978                    .stage_start("Loading Decoder CLIP encoder (1024-dim)");
979                let dec_clip_start = Instant::now();
980                let dec_clip_config = stable_diffusion::clip::Config::wuerstchen();
981                let decoder_clip = stable_diffusion::build_clip_transformer(
982                    &dec_clip_config,
983                    &dec_clip_path,
984                    &clip_device,
985                    DType::F32,
986                )?;
987                self.base.progress.stage_done(
988                    "Loading Decoder CLIP encoder (1024-dim)",
989                    dec_clip_start.elapsed(),
990                );
991
992                let (prior_emb, decoder_emb) = self.encode_prompt_pair_cached(
993                    &prior_clip,
994                    &prior_tokenizer,
995                    &decoder_clip,
996                    &decoder_tokenizer,
997                    &req.prompt,
998                    negative_prompt,
999                    &device,
1000                    &clip_device,
1001                    dtype,
1002                    prior_guidance,
1003                    decoder_guidance,
1004                )?;
1005
1006                drop(prior_clip);
1007                drop(prior_tokenizer);
1008                self.base.progress.info("Freed Prior CLIP-G encoder");
1009
1010                (prior_emb, decoder_emb)
1011            };
1012        Self::debug_tensor_stats("prior_text_embeddings", &prior_text_embeddings);
1013        Self::debug_tensor_stats("decoder_text_embeddings", &decoder_text_embeddings);
1014        tracing::info!("CLIP encoders processed (sequential mode)");
1015
1016        let is_img2img = req.source_image.is_some();
1017
1018        // --- Phase 2: img2img path (VQ-GAN encode, skip Prior) or txt2img path (Prior) ---
1019        let (image_embeddings, mut decoder_latents, decoder_start_step, inpaint_ctx) =
1020            if let Some(ref source_bytes) = req.source_image {
1021                self.base
1022                    .progress
1023                    .info("img2img mode: skipping Prior, encoding source via VQ-GAN");
1024
1025                // Load VQ-GAN for encoding
1026                self.base.progress.stage_start("Loading VQ-GAN (Stage A)");
1027                let vqgan_start = Instant::now();
1028                let vqgan_vb = self.load_vqgan_var_builder(DType::F32, &device, "VQ-GAN")?;
1029                let vqgan = PaellaVQ::new(vqgan_vb)?;
1030                self.base
1031                    .progress
1032                    .stage_done("Loading VQ-GAN (Stage A)", vqgan_start.elapsed());
1033
1034                let (noised, start_step, encoded, noise) = self.prepare_img2img_latents(
1035                    &vqgan,
1036                    source_bytes,
1037                    req.width,
1038                    req.height,
1039                    req.strength,
1040                    decoder_steps,
1041                    seed,
1042                    &device,
1043                )?;
1044
1045                let (_, _, enc_h, enc_w) = encoded.dims4()?;
1046                let inpaint_ctx = crate::img2img::maybe_build_inpaint_context(
1047                    req.mask_image.as_deref(),
1048                    &encoded,
1049                    &noise,
1050                    enc_h,
1051                    enc_w,
1052                    &device,
1053                    DType::F32,
1054                )?;
1055
1056                // Use zeros for effnet conditioning (no Prior output)
1057                // The Decoder will rely on text conditioning + noised latents
1058                let (_, _, _enc_h, _enc_w) = noised.dims4()?;
1059                let prior_latent_h = (height as f64 / LATENT_DIM_SCALE).ceil() as usize;
1060                let prior_latent_w = (width as f64 / LATENT_DIM_SCALE).ceil() as usize;
1061                let image_embeddings = Tensor::zeros(
1062                    (1, PRIOR_C_IN, prior_latent_h, prior_latent_w),
1063                    DType::F32,
1064                    &device,
1065                )?;
1066
1067                drop(vqgan);
1068                device.synchronize()?;
1069                self.base
1070                    .progress
1071                    .info("Freed VQ-GAN (will reload for decode)");
1072
1073                Self::debug_tensor_stats("decoder_latents_init", &noised);
1074                (image_embeddings, noised, start_step, inpaint_ctx)
1075            } else {
1076                // --- txt2img: run Prior (Stage C) ---
1077                let prior_size = std::fs::metadata(&self.base.paths.transformer)
1078                    .map(|m| m.len())
1079                    .unwrap_or(0);
1080                let prior_activation_budget = crate::device::activation_bytes(
1081                    req.width,
1082                    req.height,
1083                    if req.guidance > 1.0 { 2 } else { 1 },
1084                    crate::device::dtype_bytes(dtype),
1085                    crate::device::ActivationFamily::Wuerstchen,
1086                );
1087                preflight_memory_check("Prior (Stage C)", prior_size, prior_activation_budget)?;
1088                if let Some(status) = memory_status_string() {
1089                    self.base.progress.info(&status);
1090                }
1091
1092                self.base.progress.stage_start("Loading Prior (Stage C)");
1093                let prior_start = Instant::now();
1094                let prior_vb = crate::weight_loader::load_safetensors_with_progress(
1095                    &[&self.base.paths.transformer],
1096                    dtype,
1097                    &device,
1098                    "Wuerstchen Prior",
1099                    &self.base.progress,
1100                )?;
1101                let prior = WPrior::new(
1102                    PRIOR_C_IN,
1103                    PRIOR_C,
1104                    PRIOR_C_COND,
1105                    PRIOR_C_R,
1106                    PRIOR_DEPTH,
1107                    PRIOR_NHEAD,
1108                    false,
1109                    prior_vb,
1110                )?;
1111                self.base
1112                    .progress
1113                    .stage_done("Loading Prior (Stage C)", prior_start.elapsed());
1114
1115                // Stage C latent dimensions: 42x compression
1116                let latent_h = (height as f64 / LATENT_DIM_SCALE).ceil() as usize;
1117                let latent_w = (width as f64 / LATENT_DIM_SCALE).ceil() as usize;
1118                device.set_seed(seed)?;
1119                let mut prior_latents =
1120                    Tensor::randn(0f32, 1f32, (1, PRIOR_C_IN, latent_h, latent_w), &device)?
1121                        .to_dtype(dtype)?;
1122                Self::debug_tensor_stats("prior_latents_init", &prior_latents);
1123
1124                self.denoise_prior(
1125                    &prior,
1126                    &prior_text_embeddings,
1127                    &mut prior_latents,
1128                    seed,
1129                    prior_steps,
1130                    prior_guidance,
1131                    &device,
1132                    dtype,
1133                )?;
1134
1135                // Scale prior output: convert from Prior latent space to Decoder conditioning space
1136                Self::debug_tensor_stats("prior_latents_denoised", &prior_latents);
1137                prior_latents = ((prior_latents * 42.)? - 1.)?;
1138                Self::debug_tensor_stats("image_embeddings", &prior_latents);
1139
1140                drop(prior);
1141                device.synchronize()?;
1142                self.base.progress.info("Freed Prior (Stage C)");
1143
1144                // Decoder latent dims derived from prior output spatial dims
1145                let prior_latents = prior_latents.to_dtype(DType::F32)?;
1146                let stage_b_h = (prior_latents.dim(2)? as f64 * LATENT_DIM_SCALE_DECODER) as usize;
1147                let stage_b_w = (prior_latents.dim(3)? as f64 * LATENT_DIM_SCALE_DECODER) as usize;
1148                device.set_seed(seed.wrapping_add(1))?;
1149                let decoder_latents =
1150                    Tensor::randn(0f32, 1f32, (1, 4, stage_b_h, stage_b_w), &device)?;
1151                Self::debug_tensor_stats("decoder_latents_init", &decoder_latents);
1152
1153                (prior_latents, decoder_latents, 0, None)
1154            };
1155        drop(prior_text_embeddings);
1156
1157        // --- Phase 3: Decoder (Stage B) ---
1158        let decoder_size = std::fs::metadata(&decoder_path)
1159            .map(|m| m.len())
1160            .unwrap_or(0);
1161        let decoder_activation_budget = crate::device::activation_bytes(
1162            req.width,
1163            req.height,
1164            if req.guidance > 1.0 { 2 } else { 1 },
1165            crate::device::dtype_bytes(DType::F32),
1166            crate::device::ActivationFamily::Wuerstchen,
1167        );
1168        preflight_memory_check("Decoder (Stage B)", decoder_size, decoder_activation_budget)?;
1169        if let Some(status) = memory_status_string() {
1170            self.base.progress.info(&status);
1171        }
1172
1173        // Decoder uses F32 — the 256x256 latent space overflows F16 range
1174        self.base.progress.stage_start("Loading Decoder (Stage B)");
1175        let dec_start = Instant::now();
1176        let decoder_vb = crate::weight_loader::load_safetensors_with_progress(
1177            &[&decoder_path],
1178            DType::F32,
1179            &device,
1180            "Wuerstchen Decoder",
1181            &self.base.progress,
1182        )?;
1183        let decoder = WDiffNeXt::new(
1184            DECODER_C_IN,
1185            DECODER_C_OUT,
1186            DECODER_C_R,
1187            DECODER_C_COND,
1188            DECODER_CLIP_EMBD,
1189            DECODER_PATCH_SIZE,
1190            false,
1191            decoder_vb,
1192        )?;
1193        self.base
1194            .progress
1195            .stage_done("Loading Decoder (Stage B)", dec_start.elapsed());
1196
1197        // Cast text embeddings to F32 for Decoder
1198        let decoder_text_embeddings = decoder_text_embeddings.to_dtype(DType::F32)?;
1199
1200        self.denoise_decoder(
1201            &decoder,
1202            &image_embeddings,
1203            &decoder_text_embeddings,
1204            &mut decoder_latents,
1205            seed,
1206            decoder_steps,
1207            decoder_start_step,
1208            decoder_guidance,
1209            inpaint_ctx.as_ref(),
1210            &device,
1211            DType::F32,
1212        )?;
1213        Self::debug_tensor_stats("decoder_latents_denoised", &decoder_latents);
1214
1215        drop(decoder);
1216        drop(image_embeddings);
1217        drop(decoder_text_embeddings);
1218        drop(inpaint_ctx);
1219        device.synchronize()?;
1220        self.base.progress.info("Freed Decoder (Stage B)");
1221
1222        // --- Phase 4: VQ-GAN decode (Stage A) ---
1223        // VQ-GAN uses F32 for pixel-space decoding regardless of model dtype
1224        let vqgan_load_label = if is_img2img {
1225            "Reloading VQ-GAN (Stage A)"
1226        } else {
1227            "Loading VQ-GAN (Stage A)"
1228        };
1229        self.base.progress.stage_start(vqgan_load_label);
1230        let vqgan_start = Instant::now();
1231        let vqgan_vb = self.load_vqgan_var_builder(DType::F32, &device, "VQ-GAN")?;
1232        let vqgan = PaellaVQ::new(vqgan_vb)?;
1233        self.base
1234            .progress
1235            .stage_done(vqgan_load_label, vqgan_start.elapsed());
1236
1237        self.base.progress.stage_start("VQ-GAN decode");
1238        let decode_start = Instant::now();
1239        Self::debug_tensor_stats("decoder_latents_pre_vq", &decoder_latents);
1240        let img = vqgan.decode(&(&decoder_latents * VQGAN_SCALE)?)?;
1241        Self::debug_tensor_stats("image_pre_postprocess", &img);
1242        let img = img.clamp(0f32, 1f32)?;
1243        Self::debug_tensor_stats("image_postprocess", &img);
1244        let img = (img * 255.)?.to_dtype(DType::U8)?;
1245        let img = img.squeeze(0)?;
1246        self.base
1247            .progress
1248            .stage_done("VQ-GAN decode", decode_start.elapsed());
1249
1250        // Use actual tensor dims — VQ-GAN output may differ from requested dims
1251        // due to the 42x compression rounding in the cascade.
1252        let (_, actual_h, actual_w) = img.dims3()?;
1253        let mut output_metadata = build_output_metadata(req, seed, None);
1254        update_output_metadata_size(&mut output_metadata, actual_w as u32, actual_h as u32);
1255        let image_bytes = encode_image(
1256            &img,
1257            req.resolved_output_format(),
1258            actual_w as u32,
1259            actual_h as u32,
1260            output_metadata.as_ref(),
1261        )?;
1262
1263        let generation_time_ms = start.elapsed().as_millis() as u64;
1264        tracing::info!(
1265            generation_time_ms,
1266            seed,
1267            "sequential Wuerstchen generation complete"
1268        );
1269
1270        Ok(GenerateResponse {
1271            images: vec![ImageData {
1272                data: image_bytes,
1273                format: req.resolved_output_format(),
1274                width: req.width,
1275                height: req.height,
1276                index: 0,
1277            }],
1278            generation_time_ms,
1279            model: req.model.clone(),
1280            seed_used: seed,
1281            video: None,
1282            gpu: None,
1283        })
1284    }
1285}
1286
1287impl WuerstchenEngine {
1288    fn generate_inner(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1289        if req.scheduler.is_some() {
1290            tracing::warn!("scheduler selection not supported for Wuerstchen, ignoring");
1291        }
1292
1293        if self.base.load_strategy == LoadStrategy::Sequential {
1294            return self.generate_sequential(req);
1295        }
1296
1297        // Reload Prior/Decoder if dropped after previous VQ-GAN decode
1298        self.reload_models_if_needed()?;
1299
1300        let loaded = self
1301            .base
1302            .loaded
1303            .as_ref()
1304            .ok_or_else(|| anyhow::anyhow!("model not loaded — call load() first"))?;
1305
1306        let start = Instant::now();
1307        let seed = req.seed.unwrap_or_else(rand_seed);
1308        let width = req.width as usize;
1309        let height = req.height as usize;
1310        let prior_guidance = req.guidance;
1311        let decoder_guidance = Self::decoder_guidance();
1312        let negative_prompt = req.negative_prompt.as_deref().unwrap_or("");
1313        let prior_steps = Self::effective_prior_steps(req.steps as usize);
1314        let decoder_steps = 12;
1315
1316        tracing::info!(
1317            prompt = %req.prompt,
1318            seed, width, height,
1319            prior_steps,
1320            decoder_steps,
1321            prior_guidance,
1322            decoder_guidance,
1323            "starting Wuerstchen generation"
1324        );
1325
1326        // 1. Encode prompt with Prior CLIP-G (1280-dim)
1327        let (prior_text_embeddings, decoder_text_embeddings) = self.encode_prompt_pair_cached(
1328            &loaded.prior_clip,
1329            &loaded.prior_tokenizer,
1330            &loaded.decoder_clip,
1331            &loaded.decoder_tokenizer,
1332            &req.prompt,
1333            negative_prompt,
1334            &loaded.device,
1335            &loaded.clip_device,
1336            loaded.dtype,
1337            prior_guidance,
1338            decoder_guidance,
1339        )?;
1340        Self::debug_tensor_stats("prior_text_embeddings", &prior_text_embeddings);
1341        Self::debug_tensor_stats("decoder_text_embeddings", &decoder_text_embeddings);
1342
1343        // 2. Prepare latents: img2img (VQ-GAN encode, skip Prior) or txt2img (Prior)
1344        let (image_embeddings, mut decoder_latents, decoder_start_step, inpaint_ctx) =
1345            if let Some(ref source_bytes) = req.source_image {
1346                self.base
1347                    .progress
1348                    .info("img2img mode: skipping Prior, encoding source via VQ-GAN");
1349
1350                let (noised, start_step, encoded, noise) = self.prepare_img2img_latents(
1351                    &loaded.vqgan,
1352                    source_bytes,
1353                    req.width,
1354                    req.height,
1355                    req.strength,
1356                    decoder_steps,
1357                    seed,
1358                    &loaded.device,
1359                )?;
1360
1361                let (_, _, enc_h, enc_w) = encoded.dims4()?;
1362                let inpaint_ctx = crate::img2img::maybe_build_inpaint_context(
1363                    req.mask_image.as_deref(),
1364                    &encoded,
1365                    &noise,
1366                    enc_h,
1367                    enc_w,
1368                    &loaded.device,
1369                    DType::F32,
1370                )?;
1371
1372                // Use zeros for effnet conditioning (no Prior output)
1373                let prior_latent_h = (height as f64 / LATENT_DIM_SCALE).ceil() as usize;
1374                let prior_latent_w = (width as f64 / LATENT_DIM_SCALE).ceil() as usize;
1375                let image_embeddings = Tensor::zeros(
1376                    (1, PRIOR_C_IN, prior_latent_h, prior_latent_w),
1377                    DType::F32,
1378                    &loaded.device,
1379                )?;
1380
1381                Self::debug_tensor_stats("decoder_latents_init", &noised);
1382                (image_embeddings, noised, start_step, inpaint_ctx)
1383            } else {
1384                // txt2img: run Stage C (Prior) to generate image embeddings
1385                let latent_h = (height as f64 / LATENT_DIM_SCALE).ceil() as usize;
1386                let latent_w = (width as f64 / LATENT_DIM_SCALE).ceil() as usize;
1387                loaded.device.set_seed(seed)?;
1388                let mut prior_latents = Tensor::randn(
1389                    0f32,
1390                    1f32,
1391                    (1, PRIOR_C_IN, latent_h, latent_w),
1392                    &loaded.device,
1393                )?
1394                .to_dtype(loaded.dtype)?;
1395                Self::debug_tensor_stats("prior_latents_init", &prior_latents);
1396
1397                let prior = loaded
1398                    .prior
1399                    .as_ref()
1400                    .ok_or_else(|| anyhow::anyhow!("Prior not loaded"))?;
1401                self.denoise_prior(
1402                    prior,
1403                    &prior_text_embeddings,
1404                    &mut prior_latents,
1405                    seed,
1406                    prior_steps,
1407                    prior_guidance,
1408                    &loaded.device,
1409                    loaded.dtype,
1410                )?;
1411
1412                // Scale prior output: convert from Prior latent space to Decoder conditioning space
1413                Self::debug_tensor_stats("prior_latents_denoised", &prior_latents);
1414                prior_latents = ((prior_latents * 42.)? - 1.)?;
1415                Self::debug_tensor_stats("image_embeddings", &prior_latents);
1416
1417                // Stage B (Decoder): decode prior latents to VQ-GAN latent space
1418                let prior_latents = prior_latents.to_dtype(DType::F32)?;
1419                let stage_b_h = (prior_latents.dim(2)? as f64 * LATENT_DIM_SCALE_DECODER) as usize;
1420                let stage_b_w = (prior_latents.dim(3)? as f64 * LATENT_DIM_SCALE_DECODER) as usize;
1421                loaded.device.set_seed(seed.wrapping_add(1))?;
1422                let decoder_latents =
1423                    Tensor::randn(0f32, 1f32, (1, 4, stage_b_h, stage_b_w), &loaded.device)?;
1424                Self::debug_tensor_stats("decoder_latents_init", &decoder_latents);
1425
1426                (prior_latents, decoder_latents, 0, None)
1427            };
1428
1429        // 3. Stage B (Decoder): denoise
1430        // Cast text embeddings to F32 for Decoder (F16 overflows)
1431        let decoder_text_embeddings = decoder_text_embeddings.to_dtype(DType::F32)?;
1432
1433        let decoder = loaded
1434            .decoder
1435            .as_ref()
1436            .ok_or_else(|| anyhow::anyhow!("Decoder not loaded"))?;
1437        self.denoise_decoder(
1438            decoder,
1439            &image_embeddings,
1440            &decoder_text_embeddings,
1441            &mut decoder_latents,
1442            seed,
1443            decoder_steps,
1444            decoder_start_step,
1445            decoder_guidance,
1446            inpaint_ctx.as_ref(),
1447            &loaded.device,
1448            DType::F32,
1449        )?;
1450        Self::debug_tensor_stats("decoder_latents_denoised", &decoder_latents);
1451
1452        // Drop Prior and Decoder before VQ-GAN decode to free VRAM.
1453        drop(inpaint_ctx);
1454        let _ = loaded;
1455        let loaded = self.base.loaded.as_mut().unwrap();
1456        loaded.prior = None;
1457        loaded.decoder = None;
1458        loaded.device.synchronize()?;
1459        tracing::info!("Prior + Decoder dropped to free VRAM for VQ-GAN decode");
1460        let _ = loaded;
1461        let loaded = self.base.loaded.as_ref().unwrap();
1462
1463        // 4. Stage A (VQ-GAN): decode to pixel space
1464        self.base.progress.stage_start("VQ-GAN decode");
1465        let decode_start = Instant::now();
1466        Self::debug_tensor_stats("decoder_latents_pre_vq", &decoder_latents);
1467        let img = loaded.vqgan.decode(&(&decoder_latents * VQGAN_SCALE)?)?;
1468        Self::debug_tensor_stats("image_pre_postprocess", &img);
1469        let img = img.clamp(0f32, 1f32)?;
1470        Self::debug_tensor_stats("image_postprocess", &img);
1471        let img = (img * 255.)?.to_dtype(DType::U8)?;
1472        let img = img.squeeze(0)?;
1473        self.base
1474            .progress
1475            .stage_done("VQ-GAN decode", decode_start.elapsed());
1476
1477        // 5. Encode to image format
1478        // Use actual tensor dims — VQ-GAN output may differ from requested dims
1479        // due to the 42x compression rounding in the cascade.
1480        let (_, actual_h, actual_w) = img.dims3()?;
1481        let mut output_metadata = build_output_metadata(req, seed, None);
1482        update_output_metadata_size(&mut output_metadata, actual_w as u32, actual_h as u32);
1483        let image_bytes = encode_image(
1484            &img,
1485            req.resolved_output_format(),
1486            actual_w as u32,
1487            actual_h as u32,
1488            output_metadata.as_ref(),
1489        )?;
1490
1491        let generation_time_ms = start.elapsed().as_millis() as u64;
1492        tracing::info!(generation_time_ms, seed, "Wuerstchen generation complete");
1493
1494        Ok(GenerateResponse {
1495            images: vec![ImageData {
1496                data: image_bytes,
1497                format: req.resolved_output_format(),
1498                width: req.width,
1499                height: req.height,
1500                index: 0,
1501            }],
1502            generation_time_ms,
1503            model: req.model.clone(),
1504            seed_used: seed,
1505            video: None,
1506            gpu: None,
1507        })
1508    }
1509}
1510
1511impl InferenceEngine for WuerstchenEngine {
1512    fn generate(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1513        self.pending_placement = req.placement.clone();
1514        let result = self.generate_inner(req);
1515        self.pending_placement = None;
1516        result
1517    }
1518
1519    fn model_name(&self) -> &str {
1520        self.base.model_name()
1521    }
1522
1523    fn is_loaded(&self) -> bool {
1524        self.base.is_loaded()
1525    }
1526
1527    fn load(&mut self) -> Result<()> {
1528        WuerstchenEngine::load(self)
1529    }
1530
1531    fn unload(&mut self) {
1532        self.base.unload();
1533        clear_cache(&self.prompt_cache);
1534    }
1535
1536    fn set_on_progress(&mut self, callback: ProgressCallback) {
1537        self.base.set_on_progress(callback);
1538    }
1539
1540    fn clear_on_progress(&mut self) {
1541        self.base.clear_on_progress();
1542    }
1543
1544    fn model_paths(&self) -> Option<&mold_core::ModelPaths> {
1545        Some(&self.base.paths)
1546    }
1547}
1548
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::LoadStrategy;
    use crate::shared_pool::SharedPool;
    use mold_core::ModelPaths;
    use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
    use std::fs;
    use std::path::{Path, PathBuf};
    use std::sync::{Arc, Mutex};
    use std::time::{SystemTime, UNIX_EPOCH};
    use tokenizers::models::bpe::BPE;

    /// Per-test scratch directory under the system temp dir.
    ///
    /// The directory is created on construction and removed when the guard is
    /// dropped. Cleanup therefore also happens when an assertion panics
    /// mid-test; a trailing `remove_dir_all` call would be skipped in that case.
    struct TempTestDir {
        path: PathBuf,
    }

    impl TempTestDir {
        /// Creates `<temp>/<prefix>-<pid>-<nanos>`; the pid and nanosecond
        /// timestamp make collisions between runs unlikely.
        fn new(prefix: &str) -> Self {
            let suffix = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap()
                .as_nanos();
            let path =
                std::env::temp_dir().join(format!("{prefix}-{}-{suffix}", std::process::id()));
            fs::create_dir_all(&path).unwrap();
            Self { path }
        }

        fn path(&self) -> &Path {
            &self.path
        }
    }

    impl Drop for TempTestDir {
        fn drop(&mut self) {
            // Best effort: a cleanup failure must not mask the test's real outcome.
            let _ = fs::remove_dir_all(&self.path);
        }
    }

    /// Writes a small placeholder file named `name` into `dir` and returns its path.
    fn touch(dir: &Path, name: &str) -> PathBuf {
        let path = dir.join(name);
        fs::write(&path, b"test").unwrap();
        path
    }

    /// Builds `ModelPaths` for a Wuerstchen engine.
    ///
    /// The decoder CLIP goes into `clip_encoder`/`clip_tokenizer`, and the prior
    /// CLIP goes into `clip_encoder_2`/`clip_tokenizer_2`. Every other slot is
    /// left empty.
    fn wuerstchen_model_paths(
        transformer: PathBuf,
        decoder: Option<PathBuf>,
        vae: PathBuf,
        prior_clip_encoder: Option<PathBuf>,
        prior_clip_tokenizer: Option<PathBuf>,
        decoder_clip_encoder: Option<PathBuf>,
        decoder_clip_tokenizer: Option<PathBuf>,
    ) -> ModelPaths {
        ModelPaths {
            transformer,
            transformer_shards: vec![],
            vae,
            spatial_upscaler: None,
            temporal_upscaler: None,
            distilled_lora: None,
            t5_encoder: None,
            clip_encoder: decoder_clip_encoder,
            t5_tokenizer: None,
            clip_tokenizer: decoder_clip_tokenizer,
            clip_encoder_2: prior_clip_encoder,
            clip_tokenizer_2: prior_clip_tokenizer,
            text_encoder_files: vec![],
            text_tokenizer: None,
            decoder,
        }
    }

    /// Minimal WordLevel tokenizer: "hello" maps to id 11, and every other token
    /// (including `<|endoftext|>` itself, which is also the unk token) maps to 7.
    fn test_tokenizer() -> tokenizers::Tokenizer {
        let tokenizer_json = r#"{
  "version": "1.0",
  "truncation": null,
  "padding": null,
  "added_tokens": [],
  "normalizer": null,
  "pre_tokenizer": null,
  "post_processor": null,
  "decoder": null,
  "model": {
    "type": "WordLevel",
    "vocab": {
      "<|endoftext|>": 7,
      "hello": 11
    },
    "unk_token": "<|endoftext|>"
  }
}"#;
        tokenizers::Tokenizer::from_bytes(tokenizer_json.as_bytes()).unwrap()
    }

    #[test]
    fn prompt_cache_key_includes_negative_prompt_and_cfg() {
        let base = WuerstchenEngine::prompt_cache_key("hello", "", false, false);
        let neg = WuerstchenEngine::prompt_cache_key("hello", "bad", false, false);
        let prior_cfg = WuerstchenEngine::prompt_cache_key("hello", "", true, false);
        let decoder_cfg = WuerstchenEngine::prompt_cache_key("hello", "", false, true);

        assert_ne!(base, neg);
        assert_ne!(base, prior_cfg);
        assert_ne!(base, decoder_cfg);
        assert_ne!(prior_cfg, decoder_cfg);
    }

    #[test]
    fn tokenize_uses_clip_pad_token() {
        let tokenizer = test_tokenizer();
        let clip_config = stable_diffusion::clip::Config::wuerstchen();
        let (tokens, tokens_len) =
            WuerstchenEngine::tokenize(&tokenizer, "hello", 4, &clip_config, &Device::Cpu).unwrap();
        let ids = tokens.squeeze(0).unwrap().to_vec1::<u32>().unwrap();

        assert_eq!(tokens_len, 1);
        assert_eq!(ids, vec![11, 7, 7, 7]);
    }

    #[test]
    fn tokenize_falls_back_to_pad_token_for_empty_prompt() {
        let tokenizer = test_tokenizer();
        let clip_config = stable_diffusion::clip::Config::wuerstchen();
        let (tokens, tokens_len) =
            WuerstchenEngine::tokenize(&tokenizer, "", 3, &clip_config, &Device::Cpu).unwrap();
        let ids = tokens.squeeze(0).unwrap().to_vec1::<u32>().unwrap();

        assert_eq!(tokens_len, 1);
        assert_eq!(ids, vec![7, 7, 7]);
    }

    #[test]
    fn effective_prior_steps_passes_through() {
        assert_eq!(WuerstchenEngine::effective_prior_steps(30), 30);
        assert_eq!(WuerstchenEngine::effective_prior_steps(60), 60);
        assert_eq!(WuerstchenEngine::effective_prior_steps(20), 20);
    }

    #[test]
    fn ddpmw_add_noise_matches_reference_formula() {
        let dev = Device::Cpu;
        let original = Tensor::from_vec(vec![2.0f32, -1.0], (1, 2), &dev).unwrap();
        let noise = Tensor::from_vec(vec![3.0f32, 4.0], (1, 2), &dev).unwrap();
        let t = 0.5f64;

        let actual = WuerstchenEngine::ddpmw_add_noise(&original, &noise, t)
            .unwrap()
            .flatten_all()
            .unwrap()
            .to_vec1::<f32>()
            .unwrap();

        // Reference: cosine alpha-cumprod schedule with offset s, normalised by
        // its value at t = 0 and clamped, exactly as written below.
        let s = 0.008f64;
        let init_alpha_cumprod = (s / (1.0 + s) * std::f64::consts::PI).cos().powi(2);
        let alpha_cumprod = ((t + s) / (1.0 + s) * std::f64::consts::PI * 0.5)
            .cos()
            .powi(2)
            / init_alpha_cumprod;
        let alpha_cumprod = alpha_cumprod.clamp(0.0001, 0.9999);
        let sqrt_alpha = alpha_cumprod.sqrt() as f32;
        let sqrt_one_minus_alpha = (1.0 - alpha_cumprod).sqrt() as f32;
        let expected = [
            2.0f32 * sqrt_alpha + 3.0 * sqrt_one_minus_alpha,
            -sqrt_alpha + 4.0 * sqrt_one_minus_alpha,
        ];

        // `zip` stops at the shorter side, so pin the length first; otherwise an
        // empty or truncated output would pass vacuously.
        assert_eq!(actual.len(), expected.len(), "unexpected element count");
        for (i, (got, want)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - want).abs() < 1e-5,
                "element {i}: got {got}, expected {want}"
            );
        }
    }

    #[test]
    fn validate_paths_falls_back_to_prior_clip_when_decoder_clip_missing() {
        let dir = TempTestDir::new("mold-wuerstchen-validate-ok");
        let transformer = touch(dir.path(), "prior.safetensors");
        let decoder = touch(dir.path(), "decoder.safetensors");
        let vae = touch(dir.path(), "vqgan.safetensors");
        let prior_clip_encoder = touch(dir.path(), "prior-clip.safetensors");
        let prior_clip_tokenizer = touch(dir.path(), "prior-tokenizer.json");

        let engine = WuerstchenEngine::new(
            "wuerstchen-v2:fp16".to_string(),
            wuerstchen_model_paths(
                transformer,
                Some(decoder.clone()),
                vae,
                Some(prior_clip_encoder.clone()),
                Some(prior_clip_tokenizer.clone()),
                None,
                None,
            ),
            LoadStrategy::Sequential,
            0,
            None,
        );

        let (
            decoder_path,
            resolved_prior_clip_encoder,
            resolved_prior_clip_tokenizer,
            resolved_decoder_clip_encoder,
            resolved_decoder_clip_tokenizer,
        ) = engine.validate_paths().unwrap();

        assert_eq!(decoder_path, decoder);
        assert_eq!(resolved_prior_clip_encoder, prior_clip_encoder);
        assert_eq!(resolved_prior_clip_tokenizer, prior_clip_tokenizer);
        assert_eq!(resolved_decoder_clip_encoder, resolved_prior_clip_encoder);
        assert_eq!(
            resolved_decoder_clip_tokenizer,
            resolved_prior_clip_tokenizer
        );
    }

    #[test]
    fn wuerstchen_loads_clip_tokenizers_through_shared_pool() {
        let dir = TempTestDir::new("mold-wuerstchen-tokenizer-pool");
        let tokenizer_path = dir.path().join("tokenizer.json");
        tokenizers::Tokenizer::new(BPE::default())
            .save(&tokenizer_path, false)
            .unwrap();

        let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
        let pooled = shared_pool
            .lock()
            .unwrap()
            .load_tokenizer(&tokenizer_path)
            .unwrap();

        let engine = WuerstchenEngine::new(
            "wuerstchen-v2:fp16".to_string(),
            wuerstchen_model_paths(
                dir.path().join("prior.safetensors"),
                Some(dir.path().join("decoder.safetensors")),
                dir.path().join("vqgan.safetensors"),
                Some(dir.path().join("prior-clip.safetensors")),
                Some(tokenizer_path.clone()),
                None,
                None,
            ),
            LoadStrategy::Sequential,
            0,
            Some(shared_pool),
        );

        let loaded = engine
            .load_clip_tokenizer(&tokenizer_path, "Prior CLIP-G")
            .unwrap();

        // The engine must hand back the pool's instance, not a fresh load.
        assert!(Arc::ptr_eq(&pooled, &loaded));
    }

    #[test]
    fn wuerstchen_loads_vqgan_tensors_through_shared_pool() {
        let dir = TempTestDir::new("mold-wuerstchen-vqgan-pool");
        let vqgan_path = dir.path().join("vqgan.safetensors");
        let data = [1.0f32, 2.0, 3.0, 4.0];
        let bytes: Vec<u8> = data.iter().flat_map(|value| value.to_le_bytes()).collect();
        let view = TensorView::new(SafeDtype::F32, vec![2, 2], &bytes).unwrap();
        serialize_to_file([("decoder.0.weight", view)], &None, &vqgan_path).unwrap();

        let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
        let pooled = shared_pool
            .lock()
            .unwrap()
            .load_safetensors_cpu_tensors(std::slice::from_ref(&vqgan_path))
            .unwrap()
            .unwrap();

        let engine = WuerstchenEngine::new(
            "wuerstchen-v2:fp16".to_string(),
            wuerstchen_model_paths(
                dir.path().join("prior.safetensors"),
                Some(dir.path().join("decoder.safetensors")),
                vqgan_path,
                Some(dir.path().join("prior-clip.safetensors")),
                Some(dir.path().join("prior-tokenizer.json")),
                None,
                None,
            ),
            LoadStrategy::Sequential,
            0,
            Some(shared_pool),
        );

        let loaded = engine.load_vqgan_cpu_tensors().unwrap().unwrap();

        // The engine must hand back the pool's cached tensors, not a fresh load.
        assert!(Arc::ptr_eq(&pooled, &loaded));
    }

    #[test]
    fn validate_paths_requires_decoder_and_existing_files() {
        let dir = TempTestDir::new("mold-wuerstchen-validate-missing");
        let transformer = touch(dir.path(), "prior.safetensors");
        let vae = touch(dir.path(), "vqgan.safetensors");
        let prior_clip_encoder = touch(dir.path(), "prior-clip.safetensors");
        let prior_clip_tokenizer = touch(dir.path(), "prior-tokenizer.json");

        let missing_decoder_engine = WuerstchenEngine::new(
            "wuerstchen-v2:fp16".to_string(),
            wuerstchen_model_paths(
                transformer.clone(),
                None,
                vae.clone(),
                Some(prior_clip_encoder.clone()),
                Some(prior_clip_tokenizer.clone()),
                None,
                None,
            ),
            LoadStrategy::Sequential,
            0,
            None,
        );
        let err = missing_decoder_engine
            .validate_paths()
            .unwrap_err()
            .to_string();
        assert!(
            err.contains("Decoder (Stage B) path required"),
            "unexpected error: {err}"
        );

        let missing_file_engine = WuerstchenEngine::new(
            "wuerstchen-v2:fp16".to_string(),
            wuerstchen_model_paths(
                transformer,
                Some(dir.path().join("missing-decoder.safetensors")),
                vae,
                Some(prior_clip_encoder),
                Some(prior_clip_tokenizer),
                None,
                None,
            ),
            LoadStrategy::Sequential,
            0,
            None,
        );
        let err = missing_file_engine.validate_paths().unwrap_err().to_string();
        assert!(
            err.contains("decoder (Stage B) file not found"),
            "unexpected error: {err}"
        );
    }
}