1use anyhow::{bail, Result};
2use candle_core::{DType, Device, Tensor};
3use candle_nn::VarBuilder;
4use candle_transformers::models::stable_diffusion;
5use candle_transformers::models::wuerstchen::ddpm::{DDPMWScheduler, DDPMWSchedulerConfig};
6use candle_transformers::models::wuerstchen::diffnext::WDiffNeXt;
7use candle_transformers::models::wuerstchen::paella_vq::PaellaVQ;
8use candle_transformers::models::wuerstchen::prior::WPrior;
9use mold_core::{GenerateRequest, GenerateResponse, ImageData, ModelPaths};
10use std::collections::HashMap;
11use std::path::Path;
12use std::sync::{Arc, Mutex};
13use std::time::Instant;
14use tokenizers::Tokenizer;
15
16use crate::cache::{
17 clear_cache, get_or_insert_cached_tensor_pair, restore_cached_tensor_pair, CachedTensorPair,
18 LruCache, DEFAULT_PROMPT_CACHE_CAPACITY,
19};
20use crate::device::{check_memory_budget, memory_status_string, preflight_memory_check};
21use crate::engine::{rand_seed, InferenceEngine, LoadStrategy};
22use crate::engine_base::EngineBase;
23use crate::image::{build_output_metadata, encode_image, update_output_metadata_size};
24use crate::img_utils;
25use crate::progress::{ProgressCallback, ProgressEvent};
26
// ---- Stage C (prior) architecture, passed positionally to `WPrior::new` ----

/// Channel count of the prior latents (also the channel count of the image embeddings
/// handed to the decoder).
const PRIOR_C_IN: usize = 16;
/// Hidden width of the prior.
const PRIOR_C: usize = 1_536;
/// Conditioning width of the prior; matches the 1280-dim prior CLIP text embeddings.
const PRIOR_C_COND: usize = 1_280;
/// `c_r` argument of `WPrior::new`.
const PRIOR_C_R: usize = 64;
/// Number of prior blocks (`depth` argument of `WPrior::new`).
const PRIOR_DEPTH: usize = 32;
/// Attention head count of the prior.
const PRIOR_NHEAD: usize = 24;

// ---- Stage B (decoder) architecture, passed positionally to `WDiffNeXt::new` ----

/// Decoder input channels.
const DECODER_C_IN: usize = 4;
/// Decoder output channels.
const DECODER_C_OUT: usize = 4;
/// `c_r` argument of `WDiffNeXt::new`.
const DECODER_C_R: usize = 64;
/// Decoder conditioning width.
const DECODER_C_COND: usize = 1_024;
/// Width of the decoder CLIP text embeddings (1024-dim).
const DECODER_CLIP_EMBD: usize = 1_024;
/// Patch size of the decoder.
const DECODER_PATCH_SIZE: usize = 2;

/// Pixel-to-prior-latent divisor: prior latents are `ceil(pixels / 42.67)` per side.
const LATENT_DIM_SCALE: f64 = 42.67;

/// Prior-latent-to-Stage-B multiplier: Stage B latents are `floor(prior_dim * 10.67)`
/// per side.
const LATENT_DIM_SCALE_DECODER: f64 = 10.67;

/// VQ-GAN latent scale: source-image encodings are divided by this before they enter
/// the decoder.
const VQGAN_SCALE: f64 = 0.3764;
53
54struct LoadedWuerstchen {
56 prior: Option<WPrior>,
58 decoder: Option<WDiffNeXt>,
60 vqgan: PaellaVQ,
61 prior_clip: stable_diffusion::clip::ClipTextTransformer,
62 decoder_clip: stable_diffusion::clip::ClipTextTransformer,
63 prior_tokenizer: Arc<Tokenizer>,
64 decoder_tokenizer: Arc<Tokenizer>,
65 device: Device,
66 clip_device: Device,
67 dtype: DType,
68}
69
70pub struct WuerstchenEngine {
74 base: EngineBase<LoadedWuerstchen>,
75 prompt_cache: Mutex<LruCache<String, CachedTensorPair>>,
76 pending_placement: Option<mold_core::types::DevicePlacement>,
77 shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
78}
79
80impl WuerstchenEngine {
81 fn debug_tensor_stats(name: &str, tensor: &Tensor) {
82 if std::env::var_os("MOLD_WUERSTCHEN_DEBUG").is_none() {
83 return;
84 }
85
86 let stats = || -> Result<String> {
87 let dims = tensor.dims().to_vec();
88 let dtype = tensor.dtype();
89 let flat = tensor
90 .to_device(&Device::Cpu)?
91 .to_dtype(DType::F32)?
92 .flatten_all()?
93 .to_vec1::<f32>()?;
94
95 if flat.is_empty() {
96 return Ok(format!(
97 "[wuerstchen-debug] {name}: shape={dims:?} dtype={dtype:?} <empty>"
98 ));
99 }
100
101 let mut min = f32::INFINITY;
102 let mut max = f32::NEG_INFINITY;
103 let mut sum = 0.0f64;
104 let mut sum_sq = 0.0f64;
105 let mut nan_count = 0usize;
106 let mut inf_count = 0usize;
107 let mut finite_count = 0usize;
108
109 for &value in &flat {
110 if value.is_nan() {
111 nan_count += 1;
112 continue;
113 }
114 if value.is_infinite() {
115 inf_count += 1;
116 continue;
117 }
118 min = min.min(value);
119 max = max.max(value);
120 let value = value as f64;
121 sum += value;
122 sum_sq += value * value;
123 finite_count += 1;
124 }
125
126 let (mean, std) = if finite_count > 0 {
127 let mean = sum / finite_count as f64;
128 let variance = (sum_sq / finite_count as f64) - (mean * mean);
129 (mean, variance.max(0.0).sqrt())
130 } else {
131 (f64::NAN, f64::NAN)
132 };
133 let checksum16: f64 = flat.iter().take(16).map(|&v| v as f64).sum();
134
135 Ok(format!(
136 "[wuerstchen-debug] {name}: shape={dims:?} dtype={dtype:?} min={min:.4} max={max:.4} mean={mean:.4} std={std:.4} nan={nan_count} inf={inf_count} checksum16={checksum16:.4}"
137 ))
138 };
139
140 match stats() {
141 Ok(message) => eprintln!("{message}"),
142 Err(err) => eprintln!("[wuerstchen-debug] {name}: <failed: {err}>"),
143 }
144 }
145
146 fn prompt_cache_key(
147 prompt: &str,
148 negative_prompt: &str,
149 use_prior_cfg: bool,
150 use_decoder_cfg: bool,
151 ) -> String {
152 format!(
153 "{prompt}\u{1f}{negative_prompt}\u{1f}prior_cfg={use_prior_cfg}\u{1f}decoder_cfg={use_decoder_cfg}"
154 )
155 }
156
157 fn pad_token_id(
158 tokenizer: &tokenizers::Tokenizer,
159 clip_config: &stable_diffusion::clip::Config,
160 ) -> Result<u32> {
161 let vocab = tokenizer.get_vocab(true);
162 let token = clip_config.pad_with.as_deref().unwrap_or("<|endoftext|>");
163 vocab
164 .get(token)
165 .copied()
166 .ok_or_else(|| anyhow::anyhow!("tokenizer missing pad/eos token '{token}'"))
167 }
168
169 fn encode_clip_prompt(
170 clip: &stable_diffusion::clip::ClipTextTransformer,
171 tokenizer: &tokenizers::Tokenizer,
172 clip_config: &stable_diffusion::clip::Config,
173 prompt: &str,
174 device: &Device,
175 dtype: DType,
176 ) -> Result<Tensor> {
177 let (tokens, tokens_len) = Self::tokenize(
178 tokenizer,
179 prompt,
180 clip_config.max_position_embeddings,
181 clip_config,
182 device,
183 )?;
184 Ok(clip
185 .forward_with_mask(&tokens, tokens_len - 1)?
186 .to_dtype(dtype)?)
187 }
188
189 fn decoder_guidance() -> f64 {
190 std::env::var("MOLD_WUERSTCHEN_DECODER_GUIDANCE")
191 .ok()
192 .and_then(|value| value.parse::<f64>().ok())
193 .unwrap_or(0.0)
194 }
195
196 fn effective_prior_steps(requested_steps: usize) -> usize {
197 if requested_steps < 10 {
198 tracing::warn!(
200 steps = requested_steps,
201 "Wuerstchen prior works best with ≥20 steps"
202 );
203 }
204 requested_steps
205 }
206
207 pub fn new(
208 model_name: String,
209 paths: ModelPaths,
210 load_strategy: LoadStrategy,
211 gpu_ordinal: usize,
212 shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
213 ) -> Self {
214 Self {
215 base: EngineBase::new(model_name, paths, load_strategy, gpu_ordinal),
216 prompt_cache: Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY)),
217 pending_placement: None,
218 shared_pool,
219 }
220 }
221
222 fn load_clip_tokenizer(&self, path: &Path, label: &str) -> Result<Arc<Tokenizer>> {
223 if let Some(shared_pool) = &self.shared_pool {
224 return shared_pool.lock().unwrap().load_tokenizer(path);
225 }
226 Tokenizer::from_file(path)
227 .map(Arc::new)
228 .map_err(|e| anyhow::anyhow!("failed to load {label} tokenizer: {e}"))
229 }
230
231 fn load_vqgan_cpu_tensors(&self) -> Result<Option<Arc<HashMap<String, Tensor>>>> {
232 let Some(shared_pool) = &self.shared_pool else {
233 return Ok(None);
234 };
235 shared_pool
236 .lock()
237 .unwrap()
238 .load_safetensors_cpu_tensors(std::slice::from_ref(&self.base.paths.vae))
239 }
240
241 fn load_vqgan_var_builder<'a>(
242 &self,
243 dtype: DType,
244 device: &Device,
245 component: &str,
246 ) -> Result<VarBuilder<'a>> {
247 if let Some(tensors) = self.load_vqgan_cpu_tensors()? {
248 return Ok(crate::encoders::park::varbuilder_from_parked(
249 tensors.as_ref(),
250 dtype,
251 device,
252 ));
253 }
254
255 crate::weight_loader::load_safetensors_with_progress(
256 std::slice::from_ref(&self.base.paths.vae),
257 dtype,
258 device,
259 component,
260 &self.base.progress,
261 )
262 }
263
264 #[allow(clippy::too_many_arguments)]
265 fn encode_prompt_pair_cached(
266 &self,
267 prior_clip: &stable_diffusion::clip::ClipTextTransformer,
268 prior_tokenizer: &tokenizers::Tokenizer,
269 decoder_clip: &stable_diffusion::clip::ClipTextTransformer,
270 decoder_tokenizer: &tokenizers::Tokenizer,
271 prompt: &str,
272 negative_prompt: &str,
273 device: &Device,
274 clip_device: &Device,
275 dtype: DType,
276 prior_guidance: f64,
277 decoder_guidance: f64,
278 ) -> Result<(Tensor, Tensor)> {
279 let use_prior_cfg = prior_guidance > 1.0;
280 let use_decoder_cfg = decoder_guidance > 1.0;
281 let prior_clip_config = stable_diffusion::clip::Config::wuerstchen_prior();
282 let dec_clip_config = stable_diffusion::clip::Config::wuerstchen();
283 let cache_key =
284 Self::prompt_cache_key(prompt, negative_prompt, use_prior_cfg, use_decoder_cfg);
285 let ((prior_text_embeddings, decoder_text_embeddings), cache_hit) =
286 get_or_insert_cached_tensor_pair(&self.prompt_cache, cache_key, device, dtype, || {
287 self.base
288 .progress
289 .stage_start("Encoding prompt (Prior CLIP-G, 1280-dim)");
290 let encode_start = Instant::now();
291 let prior_text_embeddings = Self::encode_clip_prompt(
292 prior_clip,
293 prior_tokenizer,
294 &prior_clip_config,
295 prompt,
296 clip_device,
297 dtype,
298 )?;
299 let prior_text_embeddings = if use_prior_cfg {
300 let prior_negative_embeddings = Self::encode_clip_prompt(
301 prior_clip,
302 prior_tokenizer,
303 &prior_clip_config,
304 negative_prompt,
305 clip_device,
306 dtype,
307 )?;
308 Tensor::cat(&[&prior_text_embeddings, &prior_negative_embeddings], 0)?
309 } else {
310 prior_text_embeddings
311 };
312 self.base.progress.stage_done(
313 "Encoding prompt (Prior CLIP-G, 1280-dim)",
314 encode_start.elapsed(),
315 );
316
317 self.base
318 .progress
319 .stage_start("Encoding prompt (Decoder CLIP, 1024-dim)");
320 let dec_encode_start = Instant::now();
321 let decoder_text_embeddings = Self::encode_clip_prompt(
322 decoder_clip,
323 decoder_tokenizer,
324 &dec_clip_config,
325 prompt,
326 clip_device,
327 dtype,
328 )?;
329 let decoder_text_embeddings = if use_decoder_cfg {
330 let decoder_negative_embeddings = Self::encode_clip_prompt(
331 decoder_clip,
332 decoder_tokenizer,
333 &dec_clip_config,
334 negative_prompt,
335 clip_device,
336 dtype,
337 )?;
338 Tensor::cat(&[&decoder_text_embeddings, &decoder_negative_embeddings], 0)?
339 } else {
340 decoder_text_embeddings
341 };
342 self.base.progress.stage_done(
343 "Encoding prompt (Decoder CLIP, 1024-dim)",
344 dec_encode_start.elapsed(),
345 );
346 let prior_text_embeddings = prior_text_embeddings.to_device(device)?;
347 let decoder_text_embeddings = decoder_text_embeddings.to_device(device)?;
348 Ok((prior_text_embeddings, decoder_text_embeddings))
349 })?;
350 if cache_hit {
351 self.base.progress.cache_hit("prompt conditioning");
352 }
353 Ok((prior_text_embeddings, decoder_text_embeddings))
354 }
355
356 fn validate_paths(
359 &self,
360 ) -> Result<(
361 std::path::PathBuf,
362 std::path::PathBuf,
363 std::path::PathBuf,
364 std::path::PathBuf,
365 std::path::PathBuf,
366 )> {
367 let decoder = self
368 .base
369 .paths
370 .decoder
371 .as_ref()
372 .ok_or_else(|| anyhow::anyhow!("Decoder (Stage B) path required for Wuerstchen"))?
373 .clone();
374 let prior_clip_encoder = self
376 .base
377 .paths
378 .clip_encoder_2
379 .as_ref()
380 .ok_or_else(|| anyhow::anyhow!("Prior CLIP-G encoder path required for Wuerstchen"))?
381 .clone();
382 let prior_clip_tokenizer = self
383 .base
384 .paths
385 .clip_tokenizer_2
386 .as_ref()
387 .ok_or_else(|| anyhow::anyhow!("Prior CLIP-G tokenizer path required for Wuerstchen"))?
388 .clone();
389 let decoder_clip_encoder = self.base.paths.clip_encoder.clone().unwrap_or_else(|| {
393 tracing::warn!(
394 "Decoder CLIP encoder path not configured — falling back to Prior CLIP. \
395 Run `mold rm wuerstchen-v2:fp16 && mold pull wuerstchen-v2:fp16` to fix."
396 );
397 prior_clip_encoder.clone()
398 });
399 let decoder_clip_tokenizer = self
400 .base
401 .paths
402 .clip_tokenizer
403 .clone()
404 .unwrap_or_else(|| prior_clip_tokenizer.clone());
405
406 for (label, path) in [
407 ("prior (Stage C)", &self.base.paths.transformer),
408 ("decoder (Stage B)", &decoder),
409 ("vqgan (Stage A)", &self.base.paths.vae),
410 ("prior clip_encoder", &prior_clip_encoder),
411 ("prior clip_tokenizer", &prior_clip_tokenizer),
412 ("decoder clip_encoder", &decoder_clip_encoder),
413 ("decoder clip_tokenizer", &decoder_clip_tokenizer),
414 ] {
415 if !path.exists() {
416 bail!("{label} file not found: {}", path.display());
417 }
418 }
419
420 Ok((
421 decoder,
422 prior_clip_encoder,
423 prior_clip_tokenizer,
424 decoder_clip_encoder,
425 decoder_clip_tokenizer,
426 ))
427 }
428
429 fn reload_models_if_needed(&mut self) -> Result<()> {
431 let needs_reload = self
432 .base
433 .loaded
434 .as_ref()
435 .map(|l| l.prior.is_none() || l.decoder.is_none())
436 .unwrap_or(false);
437
438 if needs_reload {
439 let (decoder_path, _, _, _, _) = self.validate_paths()?;
440 let loaded = self.base.loaded.as_ref().unwrap();
441 let device = loaded.device.clone();
442 let dtype = loaded.dtype;
443 let _ = loaded;
444
445 self.base.progress.stage_start("Reloading Prior (Stage C)");
446 let reload_start = Instant::now();
447 let prior_vb = crate::weight_loader::load_safetensors_with_progress(
448 &[&self.base.paths.transformer],
449 dtype,
450 &device,
451 "Wuerstchen Prior",
452 &self.base.progress,
453 )?;
454 let prior = WPrior::new(
455 PRIOR_C_IN,
456 PRIOR_C,
457 PRIOR_C_COND,
458 PRIOR_C_R,
459 PRIOR_DEPTH,
460 PRIOR_NHEAD,
461 false,
462 prior_vb,
463 )?;
464 self.base
465 .progress
466 .stage_done("Reloading Prior (Stage C)", reload_start.elapsed());
467
468 self.base
469 .progress
470 .stage_start("Reloading Decoder (Stage B)");
471 let reload_start = Instant::now();
472 let decoder_vb = crate::weight_loader::load_safetensors_with_progress(
473 &[&decoder_path],
474 DType::F32,
475 &device,
476 "Wuerstchen Decoder",
477 &self.base.progress,
478 )?;
479 let decoder = WDiffNeXt::new(
480 DECODER_C_IN,
481 DECODER_C_OUT,
482 DECODER_C_R,
483 DECODER_C_COND,
484 DECODER_CLIP_EMBD,
485 DECODER_PATCH_SIZE,
486 false,
487 decoder_vb,
488 )?;
489 self.base
490 .progress
491 .stage_done("Reloading Decoder (Stage B)", reload_start.elapsed());
492
493 let loaded = self.base.loaded.as_mut().unwrap();
494 loaded.prior = Some(prior);
495 loaded.decoder = Some(decoder);
496 }
497 Ok(())
498 }
499
500 pub fn load(&mut self) -> Result<()> {
502 if self.base.loaded.is_some() {
503 return Ok(());
504 }
505
506 if self.base.load_strategy == LoadStrategy::Sequential {
507 return Ok(());
508 }
509
510 let (decoder_path, prior_clip_path, prior_clip_tok_path, dec_clip_path, dec_clip_tok_path) =
511 self.validate_paths()?;
512
513 tracing::info!(model = %self.base.model_name, "loading Wuerstchen model components...");
514
515 let device = crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)?;
516 let dtype = if device.is_cpu() {
520 DType::F32
521 } else {
522 DType::F16
523 };
524
525 self.base.progress.stage_start("Loading Prior (Stage C)");
527 let prior_start = Instant::now();
528 let prior_vb = crate::weight_loader::load_safetensors_with_progress(
529 &[&self.base.paths.transformer],
530 dtype,
531 &device,
532 "Wuerstchen Prior",
533 &self.base.progress,
534 )?;
535 let prior = WPrior::new(
536 PRIOR_C_IN,
537 PRIOR_C,
538 PRIOR_C_COND,
539 PRIOR_C_R,
540 PRIOR_DEPTH,
541 PRIOR_NHEAD,
542 false,
543 prior_vb,
544 )?;
545 self.base
546 .progress
547 .stage_done("Loading Prior (Stage C)", prior_start.elapsed());
548
549 self.base.progress.stage_start("Loading Decoder (Stage B)");
552 let decoder_start = Instant::now();
553 let decoder_vb = crate::weight_loader::load_safetensors_with_progress(
554 &[&decoder_path],
555 DType::F32,
556 &device,
557 "Wuerstchen Decoder",
558 &self.base.progress,
559 )?;
560 let decoder = WDiffNeXt::new(
561 DECODER_C_IN,
562 DECODER_C_OUT,
563 DECODER_C_R,
564 DECODER_C_COND,
565 DECODER_CLIP_EMBD,
566 DECODER_PATCH_SIZE,
567 false,
568 decoder_vb,
569 )?;
570 self.base
571 .progress
572 .stage_done("Loading Decoder (Stage B)", decoder_start.elapsed());
573
574 self.base.progress.stage_start("Loading VQ-GAN (Stage A)");
576 let vqgan_start = Instant::now();
577 let vqgan_vb = self.load_vqgan_var_builder(DType::F32, &device, "VQ-GAN")?;
578 let vqgan = PaellaVQ::new(vqgan_vb)?;
579 self.base
580 .progress
581 .stage_done("Loading VQ-GAN (Stage A)", vqgan_start.elapsed());
582
583 let tier1 = self
585 .pending_placement
586 .as_ref()
587 .map(|p| p.text_encoders)
588 .unwrap_or_default();
589 let clip_device = crate::device::resolve_device(Some(tier1), || Ok(device.clone()))?;
590
591 self.base
593 .progress
594 .stage_start("Loading Prior CLIP-G encoder (1280-dim)");
595 let prior_clip_start = Instant::now();
596 let prior_clip_config = stable_diffusion::clip::Config::wuerstchen_prior();
597 let prior_clip = stable_diffusion::build_clip_transformer(
598 &prior_clip_config,
599 &prior_clip_path,
600 &clip_device,
601 DType::F32,
602 )?;
603 self.base.progress.stage_done(
604 "Loading Prior CLIP-G encoder (1280-dim)",
605 prior_clip_start.elapsed(),
606 );
607
608 self.base
610 .progress
611 .stage_start("Loading Decoder CLIP encoder (1024-dim)");
612 let dec_clip_start = Instant::now();
613 let dec_clip_config = stable_diffusion::clip::Config::wuerstchen();
614 let decoder_clip = stable_diffusion::build_clip_transformer(
615 &dec_clip_config,
616 &dec_clip_path,
617 &clip_device,
618 DType::F32,
619 )?;
620 self.base.progress.stage_done(
621 "Loading Decoder CLIP encoder (1024-dim)",
622 dec_clip_start.elapsed(),
623 );
624
625 let prior_tokenizer = self.load_clip_tokenizer(&prior_clip_tok_path, "Prior CLIP-G")?;
627 let decoder_tokenizer = self.load_clip_tokenizer(&dec_clip_tok_path, "Decoder CLIP")?;
628
629 self.base.loaded = Some(LoadedWuerstchen {
630 prior: Some(prior),
631 decoder: Some(decoder),
632 vqgan,
633 prior_clip,
634 decoder_clip,
635 prior_tokenizer,
636 decoder_tokenizer,
637 device,
638 clip_device,
639 dtype,
640 });
641
642 tracing::info!(model = %self.base.model_name, "all Wuerstchen components loaded successfully");
643 Ok(())
644 }
645
646 fn tokenize(
650 tokenizer: &tokenizers::Tokenizer,
651 prompt: &str,
652 max_len: usize,
653 clip_config: &stable_diffusion::clip::Config,
654 device: &Device,
655 ) -> Result<(Tensor, usize)> {
656 let pad_id = Self::pad_token_id(tokenizer, clip_config)?;
657 let encoding = tokenizer
658 .encode(prompt, true)
659 .map_err(|e| anyhow::anyhow!("tokenization failed: {e}"))?;
660 let mut ids = encoding.get_ids().to_vec();
661 ids.truncate(max_len);
662 if ids.is_empty() {
663 ids.push(pad_id);
664 }
665 let tokens_len = ids.len();
666 while ids.len() < max_len {
667 ids.push(pad_id);
668 }
669 Ok((
670 Tensor::new(ids.as_slice(), device)?.unsqueeze(0)?,
671 tokens_len,
672 ))
673 }
674
675 #[allow(clippy::too_many_arguments)]
677 fn denoise_prior(
678 &self,
679 prior: &WPrior,
680 text_embeddings: &Tensor,
681 latents: &mut Tensor,
682 _base_seed: u64,
684 steps: usize,
685 guidance: f64,
686 device: &Device,
687 dtype: DType,
688 ) -> Result<()> {
689 let use_cfg = guidance > 1.0;
690 let scheduler = DDPMWScheduler::new(steps, DDPMWSchedulerConfig::default())?;
691 let timesteps = scheduler.timesteps().to_vec();
692
693 let label = format!("Stage C Prior ({} steps)", timesteps.len() - 1);
694 self.base.progress.stage_start(&label);
695 let start = Instant::now();
696
697 for (step_idx, &t) in timesteps.iter().enumerate() {
698 if step_idx + 1 >= timesteps.len() {
699 break; }
701 let step_start = Instant::now();
702
703 let noise_pred = if use_cfg {
704 let latent_input = Tensor::cat(&[&*latents, &*latents], 0)?;
707 let r = (Tensor::ones(2, dtype, device)? * t)?;
708 let pred = prior.forward(&latent_input, &r, text_embeddings)?;
709 let chunks = pred.chunk(2, 0)?;
710 let (pred_text, pred_uncond) = (&chunks[0], &chunks[1]);
711 (pred_uncond + ((pred_text - pred_uncond)? * guidance)?)?
712 } else {
713 let r = (Tensor::ones(1, dtype, device)? * t)?;
714 prior.forward(&*latents, &r, text_embeddings)?
715 };
716
717 *latents = scheduler.step(&noise_pred, t, &*latents)?;
718
719 self.base.progress.emit(ProgressEvent::DenoiseStep {
720 step: step_idx + 1,
721 total: timesteps.len() - 1,
722 elapsed: step_start.elapsed(),
723 });
724 }
725
726 self.base.progress.stage_done(&label, start.elapsed());
727 Ok(())
728 }
729
730 #[allow(clippy::too_many_arguments)]
739 fn denoise_decoder(
740 &self,
741 decoder: &WDiffNeXt,
742 image_embeddings: &Tensor,
743 text_embeddings: &Tensor,
744 latents: &mut Tensor,
745 _base_seed: u64,
747 steps: usize,
748 start_step: usize,
749 guidance: f64,
750 inpaint_ctx: Option<&img_utils::InpaintContext>,
751 device: &Device,
752 dtype: DType,
753 ) -> Result<()> {
754 let use_cfg = guidance > 1.0;
755 let scheduler = DDPMWScheduler::new(steps, DDPMWSchedulerConfig::default())?;
756 let timesteps = scheduler.timesteps().to_vec();
757 let active_timesteps = ×teps[start_step..timesteps.len() - 1];
759
760 let label = format!("Stage B Decoder ({} steps)", active_timesteps.len());
761 self.base.progress.stage_start(&label);
762 let start = Instant::now();
763
764 for (step_idx, &t) in active_timesteps.iter().enumerate() {
765 let step_start = Instant::now();
766
767 let noise_pred = if use_cfg {
768 let latent_input = Tensor::cat(&[&*latents, &*latents], 0)?;
769 let r = (Tensor::ones(2, dtype, device)? * t)?;
770 let effnet_input = Tensor::cat(
771 &[image_embeddings, &Tensor::zeros_like(image_embeddings)?],
772 0,
773 )?;
774 let pred =
775 decoder.forward(&latent_input, &r, &effnet_input, Some(text_embeddings))?;
776 let chunks = pred.chunk(2, 0)?;
777 let (pred_text, pred_uncond) = (&chunks[0], &chunks[1]);
778 (pred_uncond + ((pred_text - pred_uncond)? * guidance)?)?
779 } else {
780 let r = (Tensor::ones(1, dtype, device)? * t)?;
781 decoder.forward(&*latents, &r, image_embeddings, Some(text_embeddings))?
782 };
783
784 *latents = scheduler.step(&noise_pred, t, &*latents)?;
785
786 if let Some(ctx) = inpaint_ctx {
788 let noised_original = Self::ddpmw_add_noise(&ctx.original_latents, &ctx.noise, t)?;
789 *latents = crate::img2img::blend_inpaint_latents(&*latents, ctx, &noised_original)?;
790 }
791
792 self.base.progress.emit(ProgressEvent::DenoiseStep {
793 step: step_idx + 1,
794 total: active_timesteps.len(),
795 elapsed: step_start.elapsed(),
796 });
797 }
798
799 self.base.progress.stage_done(&label, start.elapsed());
800 Ok(())
801 }
802
803 fn ddpmw_add_noise(original: &Tensor, noise: &Tensor, t: f64) -> Result<Tensor> {
809 let s = 0.008f64;
811 let init_alpha_cumprod = (s / (1.0 + s) * std::f64::consts::PI).cos().powi(2);
812 let alpha_cumprod = ((t + s) / (1.0 + s) * std::f64::consts::PI * 0.5)
813 .cos()
814 .powi(2)
815 / init_alpha_cumprod;
816 let alpha_cumprod = alpha_cumprod.clamp(0.0001, 0.9999);
817
818 let sqrt_alpha = alpha_cumprod.sqrt();
819 let sqrt_one_minus_alpha = (1.0 - alpha_cumprod).sqrt();
820
821 let noised = ((original * sqrt_alpha)? + (noise * sqrt_one_minus_alpha)?)?;
822 Ok(noised)
823 }
824
825 #[allow(clippy::too_many_arguments)]
828 fn prepare_img2img_latents(
829 &self,
830 vqgan: &PaellaVQ,
831 source_bytes: &[u8],
832 width: u32,
833 height: u32,
834 strength: f64,
835 decoder_steps: usize,
836 seed: u64,
837 device: &Device,
838 ) -> Result<(Tensor, usize, Tensor, Tensor)> {
839 self.base
840 .progress
841 .stage_start("Encoding source image (VQ-GAN)");
842 let encode_start = Instant::now();
843
844 let source_tensor = img_utils::decode_source_image(
846 source_bytes,
847 width,
848 height,
849 img_utils::NormalizeRange::ZeroToOne,
850 device,
851 DType::F32,
852 )?;
853
854 let encoded = vqgan.encode(&source_tensor)?;
855 let encoded = (&encoded / VQGAN_SCALE)?;
857
858 self.base
859 .progress
860 .stage_done("Encoding source image (VQ-GAN)", encode_start.elapsed());
861
862 let start_step = crate::img2img::img2img_start_index(decoder_steps, strength);
863
864 let noise = crate::engine::seeded_randn(seed, encoded.dims(), device, DType::F32)?;
866
867 let scheduler = DDPMWScheduler::new(decoder_steps, DDPMWSchedulerConfig::default())?;
869 let timesteps = scheduler.timesteps().to_vec();
870
871 let noised = if start_step < timesteps.len() - 1 {
873 Self::ddpmw_add_noise(&encoded, &noise, timesteps[start_step])?
874 } else {
875 encoded.clone()
876 };
877
878 tracing::info!(
879 start_step,
880 total_steps = decoder_steps,
881 strength,
882 "img2img: starting decoder from step {start_step}"
883 );
884
885 Ok((noised, start_step, encoded, noise))
886 }
887
888 fn generate_sequential(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
890 let (decoder_path, prior_clip_path, prior_clip_tok_path, dec_clip_path, dec_clip_tok_path) =
891 self.validate_paths()?;
892
893 if let Some(warning) = check_memory_budget(&self.base.paths, LoadStrategy::Sequential) {
894 self.base.progress.info(&warning);
895 }
896
897 let device = crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)?;
898 let dtype = if device.is_cpu() {
901 DType::F32
902 } else {
903 DType::F16
904 };
905
906 let start = Instant::now();
907 let seed = req.seed.unwrap_or_else(rand_seed);
908 let width = req.width as usize;
909 let height = req.height as usize;
910 let prior_guidance = req.guidance;
911 let decoder_guidance = Self::decoder_guidance();
912 let negative_prompt = req.negative_prompt.as_deref().unwrap_or("");
913 let prior_steps = Self::effective_prior_steps(req.steps as usize);
914 let decoder_steps = 12;
915
916 tracing::info!(
917 prompt = %req.prompt,
918 seed, width, height,
919 prior_steps,
920 decoder_steps,
921 prior_guidance,
922 decoder_guidance,
923 "starting sequential Wuerstchen generation"
924 );
925
926 self.base
927 .progress
928 .info("Using sequential loading (load-use-drop) to minimize peak memory");
929
930 let use_prior_cfg = prior_guidance > 1.0;
932 let use_decoder_cfg = decoder_guidance > 1.0;
933 let cache_key =
934 Self::prompt_cache_key(&req.prompt, negative_prompt, use_prior_cfg, use_decoder_cfg);
935 let (prior_text_embeddings, decoder_text_embeddings) =
936 if let Some((prior_emb, decoder_emb)) =
937 restore_cached_tensor_pair(&self.prompt_cache, &cache_key, &device, dtype)?
938 {
939 self.base.progress.cache_hit("prompt conditioning");
940 (prior_emb, decoder_emb)
941 } else {
942 if let Some(status) = memory_status_string() {
943 self.base.progress.info(&status);
944 }
945
946 let prior_tokenizer =
947 self.load_clip_tokenizer(&prior_clip_tok_path, "Prior CLIP-G")?;
948
949 let tier1 = self
950 .pending_placement
951 .as_ref()
952 .map(|p| p.text_encoders)
953 .unwrap_or_default();
954 let clip_device =
955 crate::device::resolve_device(Some(tier1), || Ok(device.clone()))?;
956
957 self.base
958 .progress
959 .stage_start("Loading Prior CLIP-G encoder (1280-dim)");
960 let clip_start = Instant::now();
961 let prior_clip_config = stable_diffusion::clip::Config::wuerstchen_prior();
962 let prior_clip = stable_diffusion::build_clip_transformer(
963 &prior_clip_config,
964 &prior_clip_path,
965 &clip_device,
966 DType::F32,
967 )?;
968 self.base.progress.stage_done(
969 "Loading Prior CLIP-G encoder (1280-dim)",
970 clip_start.elapsed(),
971 );
972
973 let decoder_tokenizer =
974 self.load_clip_tokenizer(&dec_clip_tok_path, "Decoder CLIP")?;
975
976 self.base
977 .progress
978 .stage_start("Loading Decoder CLIP encoder (1024-dim)");
979 let dec_clip_start = Instant::now();
980 let dec_clip_config = stable_diffusion::clip::Config::wuerstchen();
981 let decoder_clip = stable_diffusion::build_clip_transformer(
982 &dec_clip_config,
983 &dec_clip_path,
984 &clip_device,
985 DType::F32,
986 )?;
987 self.base.progress.stage_done(
988 "Loading Decoder CLIP encoder (1024-dim)",
989 dec_clip_start.elapsed(),
990 );
991
992 let (prior_emb, decoder_emb) = self.encode_prompt_pair_cached(
993 &prior_clip,
994 &prior_tokenizer,
995 &decoder_clip,
996 &decoder_tokenizer,
997 &req.prompt,
998 negative_prompt,
999 &device,
1000 &clip_device,
1001 dtype,
1002 prior_guidance,
1003 decoder_guidance,
1004 )?;
1005
1006 drop(prior_clip);
1007 drop(prior_tokenizer);
1008 self.base.progress.info("Freed Prior CLIP-G encoder");
1009
1010 (prior_emb, decoder_emb)
1011 };
1012 Self::debug_tensor_stats("prior_text_embeddings", &prior_text_embeddings);
1013 Self::debug_tensor_stats("decoder_text_embeddings", &decoder_text_embeddings);
1014 tracing::info!("CLIP encoders processed (sequential mode)");
1015
1016 let is_img2img = req.source_image.is_some();
1017
1018 let (image_embeddings, mut decoder_latents, decoder_start_step, inpaint_ctx) =
1020 if let Some(ref source_bytes) = req.source_image {
1021 self.base
1022 .progress
1023 .info("img2img mode: skipping Prior, encoding source via VQ-GAN");
1024
1025 self.base.progress.stage_start("Loading VQ-GAN (Stage A)");
1027 let vqgan_start = Instant::now();
1028 let vqgan_vb = self.load_vqgan_var_builder(DType::F32, &device, "VQ-GAN")?;
1029 let vqgan = PaellaVQ::new(vqgan_vb)?;
1030 self.base
1031 .progress
1032 .stage_done("Loading VQ-GAN (Stage A)", vqgan_start.elapsed());
1033
1034 let (noised, start_step, encoded, noise) = self.prepare_img2img_latents(
1035 &vqgan,
1036 source_bytes,
1037 req.width,
1038 req.height,
1039 req.strength,
1040 decoder_steps,
1041 seed,
1042 &device,
1043 )?;
1044
1045 let (_, _, enc_h, enc_w) = encoded.dims4()?;
1046 let inpaint_ctx = crate::img2img::maybe_build_inpaint_context(
1047 req.mask_image.as_deref(),
1048 &encoded,
1049 &noise,
1050 enc_h,
1051 enc_w,
1052 &device,
1053 DType::F32,
1054 )?;
1055
1056 let (_, _, _enc_h, _enc_w) = noised.dims4()?;
1059 let prior_latent_h = (height as f64 / LATENT_DIM_SCALE).ceil() as usize;
1060 let prior_latent_w = (width as f64 / LATENT_DIM_SCALE).ceil() as usize;
1061 let image_embeddings = Tensor::zeros(
1062 (1, PRIOR_C_IN, prior_latent_h, prior_latent_w),
1063 DType::F32,
1064 &device,
1065 )?;
1066
1067 drop(vqgan);
1068 device.synchronize()?;
1069 self.base
1070 .progress
1071 .info("Freed VQ-GAN (will reload for decode)");
1072
1073 Self::debug_tensor_stats("decoder_latents_init", &noised);
1074 (image_embeddings, noised, start_step, inpaint_ctx)
1075 } else {
1076 let prior_size = std::fs::metadata(&self.base.paths.transformer)
1078 .map(|m| m.len())
1079 .unwrap_or(0);
1080 let prior_activation_budget = crate::device::activation_bytes(
1081 req.width,
1082 req.height,
1083 if req.guidance > 1.0 { 2 } else { 1 },
1084 crate::device::dtype_bytes(dtype),
1085 crate::device::ActivationFamily::Wuerstchen,
1086 );
1087 preflight_memory_check("Prior (Stage C)", prior_size, prior_activation_budget)?;
1088 if let Some(status) = memory_status_string() {
1089 self.base.progress.info(&status);
1090 }
1091
1092 self.base.progress.stage_start("Loading Prior (Stage C)");
1093 let prior_start = Instant::now();
1094 let prior_vb = crate::weight_loader::load_safetensors_with_progress(
1095 &[&self.base.paths.transformer],
1096 dtype,
1097 &device,
1098 "Wuerstchen Prior",
1099 &self.base.progress,
1100 )?;
1101 let prior = WPrior::new(
1102 PRIOR_C_IN,
1103 PRIOR_C,
1104 PRIOR_C_COND,
1105 PRIOR_C_R,
1106 PRIOR_DEPTH,
1107 PRIOR_NHEAD,
1108 false,
1109 prior_vb,
1110 )?;
1111 self.base
1112 .progress
1113 .stage_done("Loading Prior (Stage C)", prior_start.elapsed());
1114
1115 let latent_h = (height as f64 / LATENT_DIM_SCALE).ceil() as usize;
1117 let latent_w = (width as f64 / LATENT_DIM_SCALE).ceil() as usize;
1118 device.set_seed(seed)?;
1119 let mut prior_latents =
1120 Tensor::randn(0f32, 1f32, (1, PRIOR_C_IN, latent_h, latent_w), &device)?
1121 .to_dtype(dtype)?;
1122 Self::debug_tensor_stats("prior_latents_init", &prior_latents);
1123
1124 self.denoise_prior(
1125 &prior,
1126 &prior_text_embeddings,
1127 &mut prior_latents,
1128 seed,
1129 prior_steps,
1130 prior_guidance,
1131 &device,
1132 dtype,
1133 )?;
1134
1135 Self::debug_tensor_stats("prior_latents_denoised", &prior_latents);
1137 prior_latents = ((prior_latents * 42.)? - 1.)?;
1138 Self::debug_tensor_stats("image_embeddings", &prior_latents);
1139
1140 drop(prior);
1141 device.synchronize()?;
1142 self.base.progress.info("Freed Prior (Stage C)");
1143
1144 let prior_latents = prior_latents.to_dtype(DType::F32)?;
1146 let stage_b_h = (prior_latents.dim(2)? as f64 * LATENT_DIM_SCALE_DECODER) as usize;
1147 let stage_b_w = (prior_latents.dim(3)? as f64 * LATENT_DIM_SCALE_DECODER) as usize;
1148 device.set_seed(seed.wrapping_add(1))?;
1149 let decoder_latents =
1150 Tensor::randn(0f32, 1f32, (1, 4, stage_b_h, stage_b_w), &device)?;
1151 Self::debug_tensor_stats("decoder_latents_init", &decoder_latents);
1152
1153 (prior_latents, decoder_latents, 0, None)
1154 };
1155 drop(prior_text_embeddings);
1156
1157 let decoder_size = std::fs::metadata(&decoder_path)
1159 .map(|m| m.len())
1160 .unwrap_or(0);
1161 let decoder_activation_budget = crate::device::activation_bytes(
1162 req.width,
1163 req.height,
1164 if req.guidance > 1.0 { 2 } else { 1 },
1165 crate::device::dtype_bytes(DType::F32),
1166 crate::device::ActivationFamily::Wuerstchen,
1167 );
1168 preflight_memory_check("Decoder (Stage B)", decoder_size, decoder_activation_budget)?;
1169 if let Some(status) = memory_status_string() {
1170 self.base.progress.info(&status);
1171 }
1172
1173 self.base.progress.stage_start("Loading Decoder (Stage B)");
1175 let dec_start = Instant::now();
1176 let decoder_vb = crate::weight_loader::load_safetensors_with_progress(
1177 &[&decoder_path],
1178 DType::F32,
1179 &device,
1180 "Wuerstchen Decoder",
1181 &self.base.progress,
1182 )?;
1183 let decoder = WDiffNeXt::new(
1184 DECODER_C_IN,
1185 DECODER_C_OUT,
1186 DECODER_C_R,
1187 DECODER_C_COND,
1188 DECODER_CLIP_EMBD,
1189 DECODER_PATCH_SIZE,
1190 false,
1191 decoder_vb,
1192 )?;
1193 self.base
1194 .progress
1195 .stage_done("Loading Decoder (Stage B)", dec_start.elapsed());
1196
1197 let decoder_text_embeddings = decoder_text_embeddings.to_dtype(DType::F32)?;
1199
1200 self.denoise_decoder(
1201 &decoder,
1202 &image_embeddings,
1203 &decoder_text_embeddings,
1204 &mut decoder_latents,
1205 seed,
1206 decoder_steps,
1207 decoder_start_step,
1208 decoder_guidance,
1209 inpaint_ctx.as_ref(),
1210 &device,
1211 DType::F32,
1212 )?;
1213 Self::debug_tensor_stats("decoder_latents_denoised", &decoder_latents);
1214
1215 drop(decoder);
1216 drop(image_embeddings);
1217 drop(decoder_text_embeddings);
1218 drop(inpaint_ctx);
1219 device.synchronize()?;
1220 self.base.progress.info("Freed Decoder (Stage B)");
1221
1222 let vqgan_load_label = if is_img2img {
1225 "Reloading VQ-GAN (Stage A)"
1226 } else {
1227 "Loading VQ-GAN (Stage A)"
1228 };
1229 self.base.progress.stage_start(vqgan_load_label);
1230 let vqgan_start = Instant::now();
1231 let vqgan_vb = self.load_vqgan_var_builder(DType::F32, &device, "VQ-GAN")?;
1232 let vqgan = PaellaVQ::new(vqgan_vb)?;
1233 self.base
1234 .progress
1235 .stage_done(vqgan_load_label, vqgan_start.elapsed());
1236
1237 self.base.progress.stage_start("VQ-GAN decode");
1238 let decode_start = Instant::now();
1239 Self::debug_tensor_stats("decoder_latents_pre_vq", &decoder_latents);
1240 let img = vqgan.decode(&(&decoder_latents * VQGAN_SCALE)?)?;
1241 Self::debug_tensor_stats("image_pre_postprocess", &img);
1242 let img = img.clamp(0f32, 1f32)?;
1243 Self::debug_tensor_stats("image_postprocess", &img);
1244 let img = (img * 255.)?.to_dtype(DType::U8)?;
1245 let img = img.squeeze(0)?;
1246 self.base
1247 .progress
1248 .stage_done("VQ-GAN decode", decode_start.elapsed());
1249
1250 let (_, actual_h, actual_w) = img.dims3()?;
1253 let mut output_metadata = build_output_metadata(req, seed, None);
1254 update_output_metadata_size(&mut output_metadata, actual_w as u32, actual_h as u32);
1255 let image_bytes = encode_image(
1256 &img,
1257 req.resolved_output_format(),
1258 actual_w as u32,
1259 actual_h as u32,
1260 output_metadata.as_ref(),
1261 )?;
1262
1263 let generation_time_ms = start.elapsed().as_millis() as u64;
1264 tracing::info!(
1265 generation_time_ms,
1266 seed,
1267 "sequential Wuerstchen generation complete"
1268 );
1269
1270 Ok(GenerateResponse {
1271 images: vec![ImageData {
1272 data: image_bytes,
1273 format: req.resolved_output_format(),
1274 width: req.width,
1275 height: req.height,
1276 index: 0,
1277 }],
1278 generation_time_ms,
1279 model: req.model.clone(),
1280 seed_used: seed,
1281 video: None,
1282 gpu: None,
1283 })
1284 }
1285}
1286
1287impl WuerstchenEngine {
1288 fn generate_inner(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1289 if req.scheduler.is_some() {
1290 tracing::warn!("scheduler selection not supported for Wuerstchen, ignoring");
1291 }
1292
1293 if self.base.load_strategy == LoadStrategy::Sequential {
1294 return self.generate_sequential(req);
1295 }
1296
1297 self.reload_models_if_needed()?;
1299
1300 let loaded = self
1301 .base
1302 .loaded
1303 .as_ref()
1304 .ok_or_else(|| anyhow::anyhow!("model not loaded — call load() first"))?;
1305
1306 let start = Instant::now();
1307 let seed = req.seed.unwrap_or_else(rand_seed);
1308 let width = req.width as usize;
1309 let height = req.height as usize;
1310 let prior_guidance = req.guidance;
1311 let decoder_guidance = Self::decoder_guidance();
1312 let negative_prompt = req.negative_prompt.as_deref().unwrap_or("");
1313 let prior_steps = Self::effective_prior_steps(req.steps as usize);
1314 let decoder_steps = 12;
1315
1316 tracing::info!(
1317 prompt = %req.prompt,
1318 seed, width, height,
1319 prior_steps,
1320 decoder_steps,
1321 prior_guidance,
1322 decoder_guidance,
1323 "starting Wuerstchen generation"
1324 );
1325
1326 let (prior_text_embeddings, decoder_text_embeddings) = self.encode_prompt_pair_cached(
1328 &loaded.prior_clip,
1329 &loaded.prior_tokenizer,
1330 &loaded.decoder_clip,
1331 &loaded.decoder_tokenizer,
1332 &req.prompt,
1333 negative_prompt,
1334 &loaded.device,
1335 &loaded.clip_device,
1336 loaded.dtype,
1337 prior_guidance,
1338 decoder_guidance,
1339 )?;
1340 Self::debug_tensor_stats("prior_text_embeddings", &prior_text_embeddings);
1341 Self::debug_tensor_stats("decoder_text_embeddings", &decoder_text_embeddings);
1342
1343 let (image_embeddings, mut decoder_latents, decoder_start_step, inpaint_ctx) =
1345 if let Some(ref source_bytes) = req.source_image {
1346 self.base
1347 .progress
1348 .info("img2img mode: skipping Prior, encoding source via VQ-GAN");
1349
1350 let (noised, start_step, encoded, noise) = self.prepare_img2img_latents(
1351 &loaded.vqgan,
1352 source_bytes,
1353 req.width,
1354 req.height,
1355 req.strength,
1356 decoder_steps,
1357 seed,
1358 &loaded.device,
1359 )?;
1360
1361 let (_, _, enc_h, enc_w) = encoded.dims4()?;
1362 let inpaint_ctx = crate::img2img::maybe_build_inpaint_context(
1363 req.mask_image.as_deref(),
1364 &encoded,
1365 &noise,
1366 enc_h,
1367 enc_w,
1368 &loaded.device,
1369 DType::F32,
1370 )?;
1371
1372 let prior_latent_h = (height as f64 / LATENT_DIM_SCALE).ceil() as usize;
1374 let prior_latent_w = (width as f64 / LATENT_DIM_SCALE).ceil() as usize;
1375 let image_embeddings = Tensor::zeros(
1376 (1, PRIOR_C_IN, prior_latent_h, prior_latent_w),
1377 DType::F32,
1378 &loaded.device,
1379 )?;
1380
1381 Self::debug_tensor_stats("decoder_latents_init", &noised);
1382 (image_embeddings, noised, start_step, inpaint_ctx)
1383 } else {
1384 let latent_h = (height as f64 / LATENT_DIM_SCALE).ceil() as usize;
1386 let latent_w = (width as f64 / LATENT_DIM_SCALE).ceil() as usize;
1387 loaded.device.set_seed(seed)?;
1388 let mut prior_latents = Tensor::randn(
1389 0f32,
1390 1f32,
1391 (1, PRIOR_C_IN, latent_h, latent_w),
1392 &loaded.device,
1393 )?
1394 .to_dtype(loaded.dtype)?;
1395 Self::debug_tensor_stats("prior_latents_init", &prior_latents);
1396
1397 let prior = loaded
1398 .prior
1399 .as_ref()
1400 .ok_or_else(|| anyhow::anyhow!("Prior not loaded"))?;
1401 self.denoise_prior(
1402 prior,
1403 &prior_text_embeddings,
1404 &mut prior_latents,
1405 seed,
1406 prior_steps,
1407 prior_guidance,
1408 &loaded.device,
1409 loaded.dtype,
1410 )?;
1411
1412 Self::debug_tensor_stats("prior_latents_denoised", &prior_latents);
1414 prior_latents = ((prior_latents * 42.)? - 1.)?;
1415 Self::debug_tensor_stats("image_embeddings", &prior_latents);
1416
1417 let prior_latents = prior_latents.to_dtype(DType::F32)?;
1419 let stage_b_h = (prior_latents.dim(2)? as f64 * LATENT_DIM_SCALE_DECODER) as usize;
1420 let stage_b_w = (prior_latents.dim(3)? as f64 * LATENT_DIM_SCALE_DECODER) as usize;
1421 loaded.device.set_seed(seed.wrapping_add(1))?;
1422 let decoder_latents =
1423 Tensor::randn(0f32, 1f32, (1, 4, stage_b_h, stage_b_w), &loaded.device)?;
1424 Self::debug_tensor_stats("decoder_latents_init", &decoder_latents);
1425
1426 (prior_latents, decoder_latents, 0, None)
1427 };
1428
1429 let decoder_text_embeddings = decoder_text_embeddings.to_dtype(DType::F32)?;
1432
1433 let decoder = loaded
1434 .decoder
1435 .as_ref()
1436 .ok_or_else(|| anyhow::anyhow!("Decoder not loaded"))?;
1437 self.denoise_decoder(
1438 decoder,
1439 &image_embeddings,
1440 &decoder_text_embeddings,
1441 &mut decoder_latents,
1442 seed,
1443 decoder_steps,
1444 decoder_start_step,
1445 decoder_guidance,
1446 inpaint_ctx.as_ref(),
1447 &loaded.device,
1448 DType::F32,
1449 )?;
1450 Self::debug_tensor_stats("decoder_latents_denoised", &decoder_latents);
1451
1452 drop(inpaint_ctx);
1454 let _ = loaded;
1455 let loaded = self.base.loaded.as_mut().unwrap();
1456 loaded.prior = None;
1457 loaded.decoder = None;
1458 loaded.device.synchronize()?;
1459 tracing::info!("Prior + Decoder dropped to free VRAM for VQ-GAN decode");
1460 let _ = loaded;
1461 let loaded = self.base.loaded.as_ref().unwrap();
1462
1463 self.base.progress.stage_start("VQ-GAN decode");
1465 let decode_start = Instant::now();
1466 Self::debug_tensor_stats("decoder_latents_pre_vq", &decoder_latents);
1467 let img = loaded.vqgan.decode(&(&decoder_latents * VQGAN_SCALE)?)?;
1468 Self::debug_tensor_stats("image_pre_postprocess", &img);
1469 let img = img.clamp(0f32, 1f32)?;
1470 Self::debug_tensor_stats("image_postprocess", &img);
1471 let img = (img * 255.)?.to_dtype(DType::U8)?;
1472 let img = img.squeeze(0)?;
1473 self.base
1474 .progress
1475 .stage_done("VQ-GAN decode", decode_start.elapsed());
1476
1477 let (_, actual_h, actual_w) = img.dims3()?;
1481 let mut output_metadata = build_output_metadata(req, seed, None);
1482 update_output_metadata_size(&mut output_metadata, actual_w as u32, actual_h as u32);
1483 let image_bytes = encode_image(
1484 &img,
1485 req.resolved_output_format(),
1486 actual_w as u32,
1487 actual_h as u32,
1488 output_metadata.as_ref(),
1489 )?;
1490
1491 let generation_time_ms = start.elapsed().as_millis() as u64;
1492 tracing::info!(generation_time_ms, seed, "Wuerstchen generation complete");
1493
1494 Ok(GenerateResponse {
1495 images: vec![ImageData {
1496 data: image_bytes,
1497 format: req.resolved_output_format(),
1498 width: req.width,
1499 height: req.height,
1500 index: 0,
1501 }],
1502 generation_time_ms,
1503 model: req.model.clone(),
1504 seed_used: seed,
1505 video: None,
1506 gpu: None,
1507 })
1508 }
1509}
1510
1511impl InferenceEngine for WuerstchenEngine {
1512 fn generate(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1513 self.pending_placement = req.placement.clone();
1514 let result = self.generate_inner(req);
1515 self.pending_placement = None;
1516 result
1517 }
1518
1519 fn model_name(&self) -> &str {
1520 self.base.model_name()
1521 }
1522
1523 fn is_loaded(&self) -> bool {
1524 self.base.is_loaded()
1525 }
1526
1527 fn load(&mut self) -> Result<()> {
1528 WuerstchenEngine::load(self)
1529 }
1530
1531 fn unload(&mut self) {
1532 self.base.unload();
1533 clear_cache(&self.prompt_cache);
1534 }
1535
1536 fn set_on_progress(&mut self, callback: ProgressCallback) {
1537 self.base.set_on_progress(callback);
1538 }
1539
1540 fn clear_on_progress(&mut self) {
1541 self.base.clear_on_progress();
1542 }
1543
1544 fn model_paths(&self) -> Option<&mold_core::ModelPaths> {
1545 Some(&self.base.paths)
1546 }
1547}
1548
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::LoadStrategy;
    use crate::shared_pool::SharedPool;
    use mold_core::ModelPaths;
    use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
    use std::fs;
    use std::path::{Path, PathBuf};
    use std::sync::{Arc, Mutex};
    use std::time::{SystemTime, UNIX_EPOCH};
    use tokenizers::models::bpe::BPE;

    /// Scratch directory under the system temp dir, unique per process and call.
    ///
    /// The directory is deleted when the guard is dropped. Cleanup therefore also
    /// happens when an assertion in the owning test panics; a trailing
    /// `remove_dir_all` call would be skipped in that case.
    struct TempDir {
        path: PathBuf,
    }

    impl TempDir {
        /// Creates `<temp>/<prefix>-<pid>-<nanos>` and returns a guard owning it.
        fn new(prefix: &str) -> Self {
            let suffix = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap()
                .as_nanos();
            let path =
                std::env::temp_dir().join(format!("{prefix}-{}-{suffix}", std::process::id()));
            fs::create_dir_all(&path).unwrap();
            Self { path }
        }

        /// The directory itself.
        fn path(&self) -> &Path {
            &self.path
        }

        /// A path for `name` inside the directory. The file is not created.
        fn join(&self, name: &str) -> PathBuf {
            self.path.join(name)
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            // Best-effort cleanup: failing to delete scratch files must not mask the
            // test's own outcome.
            let _ = fs::remove_dir_all(&self.path);
        }
    }

    /// Creates `dir/name` with placeholder contents and returns its path.
    fn touch(dir: &Path, name: &str) -> PathBuf {
        let path = dir.join(name);
        fs::write(&path, b"test").unwrap();
        path
    }

    /// Builds `ModelPaths` for Wuerstchen. The Prior CLIP is stored in the `*_2`
    /// slots and the Decoder CLIP in the primary CLIP slots.
    fn wuerstchen_model_paths(
        transformer: PathBuf,
        decoder: Option<PathBuf>,
        vae: PathBuf,
        prior_clip_encoder: Option<PathBuf>,
        prior_clip_tokenizer: Option<PathBuf>,
        decoder_clip_encoder: Option<PathBuf>,
        decoder_clip_tokenizer: Option<PathBuf>,
    ) -> ModelPaths {
        ModelPaths {
            transformer,
            transformer_shards: vec![],
            vae,
            spatial_upscaler: None,
            temporal_upscaler: None,
            distilled_lora: None,
            t5_encoder: None,
            clip_encoder: decoder_clip_encoder,
            t5_tokenizer: None,
            clip_tokenizer: decoder_clip_tokenizer,
            clip_encoder_2: prior_clip_encoder,
            clip_tokenizer_2: prior_clip_tokenizer,
            text_encoder_files: vec![],
            text_tokenizer: None,
            decoder,
        }
    }

    /// Tiny word-level tokenizer: `hello` -> 11, everything else (and padding) -> 7.
    fn test_tokenizer() -> tokenizers::Tokenizer {
        let tokenizer_json = r#"{
  "version": "1.0",
  "truncation": null,
  "padding": null,
  "added_tokens": [],
  "normalizer": null,
  "pre_tokenizer": null,
  "post_processor": null,
  "decoder": null,
  "model": {
    "type": "WordLevel",
    "vocab": {
      "<|endoftext|>": 7,
      "hello": 11
    },
    "unk_token": "<|endoftext|>"
  }
}"#;
        tokenizers::Tokenizer::from_bytes(tokenizer_json.as_bytes()).unwrap()
    }

    #[test]
    fn prompt_cache_key_includes_negative_prompt_and_cfg() {
        let base = WuerstchenEngine::prompt_cache_key("hello", "", false, false);
        let neg = WuerstchenEngine::prompt_cache_key("hello", "bad", false, false);
        let prior_cfg = WuerstchenEngine::prompt_cache_key("hello", "", true, false);
        let decoder_cfg = WuerstchenEngine::prompt_cache_key("hello", "", false, true);

        assert_ne!(base, neg);
        assert_ne!(base, prior_cfg);
        assert_ne!(base, decoder_cfg);
        assert_ne!(prior_cfg, decoder_cfg);
    }

    #[test]
    fn tokenize_uses_clip_pad_token() {
        let tokenizer = test_tokenizer();
        let clip_config = stable_diffusion::clip::Config::wuerstchen();
        let (tokens, tokens_len) =
            WuerstchenEngine::tokenize(&tokenizer, "hello", 4, &clip_config, &Device::Cpu).unwrap();
        let ids = tokens.squeeze(0).unwrap().to_vec1::<u32>().unwrap();

        assert_eq!(tokens_len, 1);
        assert_eq!(ids, vec![11, 7, 7, 7]);
    }

    #[test]
    fn tokenize_falls_back_to_pad_token_for_empty_prompt() {
        let tokenizer = test_tokenizer();
        let clip_config = stable_diffusion::clip::Config::wuerstchen();
        let (tokens, tokens_len) =
            WuerstchenEngine::tokenize(&tokenizer, "", 3, &clip_config, &Device::Cpu).unwrap();
        let ids = tokens.squeeze(0).unwrap().to_vec1::<u32>().unwrap();

        assert_eq!(tokens_len, 1);
        assert_eq!(ids, vec![7, 7, 7]);
    }

    #[test]
    fn effective_prior_steps_passes_through() {
        assert_eq!(WuerstchenEngine::effective_prior_steps(30), 30);
        assert_eq!(WuerstchenEngine::effective_prior_steps(60), 60);
        assert_eq!(WuerstchenEngine::effective_prior_steps(20), 20);
    }

    #[test]
    fn ddpmw_add_noise_matches_reference_formula() {
        let dev = Device::Cpu;
        let original = Tensor::from_vec(vec![2.0f32, -1.0], (1, 2), &dev).unwrap();
        let noise = Tensor::from_vec(vec![3.0f32, 4.0], (1, 2), &dev).unwrap();
        let t = 0.5f64;

        let actual = WuerstchenEngine::ddpmw_add_noise(&original, &noise, t)
            .unwrap()
            .flatten_all()
            .unwrap()
            .to_vec1::<f32>()
            .unwrap();

        // Reference cosine schedule: alpha_cumprod(t) = cos^2(((t + s) / (1 + s)) * pi/2)
        // normalised by its value at t = 0, then clamped.
        let s = 0.008f64;
        let init_alpha_cumprod = (s / (1.0 + s) * std::f64::consts::PI).cos().powi(2);
        let alpha_cumprod = ((t + s) / (1.0 + s) * std::f64::consts::PI * 0.5)
            .cos()
            .powi(2)
            / init_alpha_cumprod;
        let alpha_cumprod = alpha_cumprod.clamp(0.0001, 0.9999);
        let sqrt_alpha = alpha_cumprod.sqrt() as f32;
        let sqrt_one_minus_alpha = (1.0 - alpha_cumprod).sqrt() as f32;
        let expected = [
            2.0f32 * sqrt_alpha + 3.0 * sqrt_one_minus_alpha,
            -sqrt_alpha + 4.0 * sqrt_one_minus_alpha,
        ];

        for (actual, expected) in actual.iter().zip(expected.iter()) {
            assert!((actual - expected).abs() < 1e-5);
        }
    }

    #[test]
    fn validate_paths_falls_back_to_prior_clip_when_decoder_clip_missing() {
        let dir = TempDir::new("mold-wuerstchen-validate-ok");
        let transformer = touch(dir.path(), "prior.safetensors");
        let decoder = touch(dir.path(), "decoder.safetensors");
        let vae = touch(dir.path(), "vqgan.safetensors");
        let prior_clip_encoder = touch(dir.path(), "prior-clip.safetensors");
        let prior_clip_tokenizer = touch(dir.path(), "prior-tokenizer.json");

        let engine = WuerstchenEngine::new(
            "wuerstchen-v2:fp16".to_string(),
            wuerstchen_model_paths(
                transformer,
                Some(decoder.clone()),
                vae,
                Some(prior_clip_encoder.clone()),
                Some(prior_clip_tokenizer.clone()),
                None,
                None,
            ),
            LoadStrategy::Sequential,
            0,
            None,
        );

        let (
            decoder_path,
            resolved_prior_clip_encoder,
            resolved_prior_clip_tokenizer,
            resolved_decoder_clip_encoder,
            resolved_decoder_clip_tokenizer,
        ) = engine.validate_paths().unwrap();

        assert_eq!(decoder_path, decoder);
        assert_eq!(resolved_prior_clip_encoder, prior_clip_encoder);
        assert_eq!(resolved_prior_clip_tokenizer, prior_clip_tokenizer);
        assert_eq!(resolved_decoder_clip_encoder, resolved_prior_clip_encoder);
        assert_eq!(
            resolved_decoder_clip_tokenizer,
            resolved_prior_clip_tokenizer
        );
    }

    #[test]
    fn wuerstchen_loads_clip_tokenizers_through_shared_pool() {
        let dir = TempDir::new("mold-wuerstchen-tokenizer-pool");
        let tokenizer_path = dir.join("tokenizer.json");
        tokenizers::Tokenizer::new(BPE::default())
            .save(&tokenizer_path, false)
            .unwrap();

        let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
        let pooled = shared_pool
            .lock()
            .unwrap()
            .load_tokenizer(&tokenizer_path)
            .unwrap();

        let engine = WuerstchenEngine::new(
            "wuerstchen-v2:fp16".to_string(),
            wuerstchen_model_paths(
                dir.join("prior.safetensors"),
                Some(dir.join("decoder.safetensors")),
                dir.join("vqgan.safetensors"),
                Some(dir.join("prior-clip.safetensors")),
                Some(tokenizer_path.clone()),
                None,
                None,
            ),
            LoadStrategy::Sequential,
            0,
            Some(shared_pool),
        );

        let loaded = engine
            .load_clip_tokenizer(&tokenizer_path, "Prior CLIP-G")
            .unwrap();

        // The engine must hand back the pool's instance, not a fresh copy.
        assert!(Arc::ptr_eq(&pooled, &loaded));
    }

    #[test]
    fn wuerstchen_loads_vqgan_tensors_through_shared_pool() {
        let dir = TempDir::new("mold-wuerstchen-vqgan-pool");
        let vqgan_path = dir.join("vqgan.safetensors");
        let data = [1.0f32, 2.0, 3.0, 4.0];
        let bytes: Vec<u8> = data.iter().flat_map(|value| value.to_le_bytes()).collect();
        let view = TensorView::new(SafeDtype::F32, vec![2, 2], &bytes).unwrap();
        serialize_to_file([("decoder.0.weight", view)], &None, &vqgan_path).unwrap();

        let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
        let pooled = shared_pool
            .lock()
            .unwrap()
            .load_safetensors_cpu_tensors(std::slice::from_ref(&vqgan_path))
            .unwrap()
            .unwrap();

        let engine = WuerstchenEngine::new(
            "wuerstchen-v2:fp16".to_string(),
            wuerstchen_model_paths(
                dir.join("prior.safetensors"),
                Some(dir.join("decoder.safetensors")),
                vqgan_path,
                Some(dir.join("prior-clip.safetensors")),
                Some(dir.join("prior-tokenizer.json")),
                None,
                None,
            ),
            LoadStrategy::Sequential,
            0,
            Some(shared_pool),
        );

        let loaded = engine.load_vqgan_cpu_tensors().unwrap().unwrap();

        // The engine must hand back the pool's instance, not a fresh copy.
        assert!(Arc::ptr_eq(&pooled, &loaded));
    }

    #[test]
    fn validate_paths_requires_decoder_and_existing_files() {
        let dir = TempDir::new("mold-wuerstchen-validate-missing");
        let transformer = touch(dir.path(), "prior.safetensors");
        let vae = touch(dir.path(), "vqgan.safetensors");
        let prior_clip_encoder = touch(dir.path(), "prior-clip.safetensors");
        let prior_clip_tokenizer = touch(dir.path(), "prior-tokenizer.json");

        let missing_decoder_engine = WuerstchenEngine::new(
            "wuerstchen-v2:fp16".to_string(),
            wuerstchen_model_paths(
                transformer.clone(),
                None,
                vae.clone(),
                Some(prior_clip_encoder.clone()),
                Some(prior_clip_tokenizer.clone()),
                None,
                None,
            ),
            LoadStrategy::Sequential,
            0,
            None,
        );
        let err = missing_decoder_engine.validate_paths().unwrap_err();
        assert!(err.to_string().contains("Decoder (Stage B) path required"));

        let missing_file_engine = WuerstchenEngine::new(
            "wuerstchen-v2:fp16".to_string(),
            wuerstchen_model_paths(
                transformer,
                Some(dir.join("missing-decoder.safetensors")),
                vae,
                Some(prior_clip_encoder),
                Some(prior_clip_tokenizer),
                None,
                None,
            ),
            LoadStrategy::Sequential,
            0,
            None,
        );
        let err = missing_file_engine.validate_paths().unwrap_err();
        assert!(err.to_string().contains("decoder (Stage B) file not found"));
    }
}