1use anyhow::{bail, Result};
18use candle_core::{DType, Device, IndexOp, Tensor, D};
19use candle_transformers::models::z_image::postprocess_image;
20use candle_transformers::quantized_var_builder;
21use mold_core::{fit_to_target_area, GenerateRequest, GenerateResponse, ImageData, ModelPaths};
22use std::collections::HashMap;
23use std::path::Path;
24use std::sync::{Arc, Mutex};
25use std::time::Instant;
26use tokenizers::Tokenizer;
27
28use super::quantized_transformer::QuantizedQwenImageTransformer2DModel;
29use super::sampling::{image_seq_len, QwenImageScheduler};
30use super::transformer::{QwenImageConfig, QwenImageTransformer2DModel};
31use super::vae::QwenImageVae;
32use crate::cache::{
33 clear_cache, prompt_text_key, CachedTensor, LruCache, DEFAULT_PROMPT_CACHE_CAPACITY,
34};
35use crate::device::{
36 effective_device_ref, fits_in_memory, fmt_gb, free_vram_bytes, memory_status_string,
37 preflight_memory_check, qwen2_vram_threshold, should_use_gpu, usable_free_vram_bytes,
38};
39use crate::encoders;
40use crate::engine::{rand_seed, InferenceEngine, LoadStrategy};
41use crate::engine_base::EngineBase;
42use crate::image::{build_output_metadata, encode_image};
43use crate::img_utils;
44use crate::progress::{ProgressCallback, ProgressEvent, ProgressReporter};
45use crate::upscaler::tiling::{upscale_with_tiling, TilingConfig};
46
/// Base VRAM figure (bytes) added to the per-resolution decode workspace estimate when
/// checking whether there is enough free VRAM around a VAE decode.
const VAE_DECODE_VRAM_THRESHOLD: u64 = 2_500_000_000;
/// Single-space prompt string. NOTE(review): presumably substituted for an empty negative
/// prompt — the usage is outside this view, confirm.
const QWEN_EMPTY_NEGATIVE_PROMPT: &str = " ";
/// Reference width; together with `QWEN_NATIVE_HEIGHT` it defines the baseline pixel area
/// used by the CFG-headroom and proactive-tiling heuristics.
const QWEN_NATIVE_WIDTH: usize = 1328;
/// Reference height; see `QWEN_NATIVE_WIDTH`.
const QWEN_NATIVE_HEIGHT: usize = 1328;
/// Headroom (bytes) assumed at the reference resolution; scaled by the pixel-area ratio in
/// `quantized_cuda_cfg_headroom`.
const QWEN_GGUF_NATIVE_CFG_HEADROOM: u64 = 14_000_000_000;
/// Lower bound (bytes) applied to the scaled headroom in `quantized_cuda_cfg_headroom`.
const QWEN_GGUF_MIN_CFG_HEADROOM: u64 = 3_000_000_000;
/// Tile sizes attempted, in this order, by `decode_vae_tiled`.
const QWEN_VAE_TILE_SIZES: [u32; 3] = [64, 32, 16];
/// NOTE(review): presumably the target pixel area for edit-model source images — the usage
/// is outside this view, confirm.
const QWEN_IMAGE_EDIT_VAE_AREA: u32 = 1024 * 1024;
/// System message embedded into the chat template built by `qwen_image_edit_prompt`.
const QWEN_IMAGE_EDIT_SYSTEM_PROMPT: &str = "Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.";

/// Threshold (bytes) passed to `should_use_gpu` when deciding whether the BF16 text encoder
/// is placed on the GPU.
const QWEN2_FP16_VRAM_THRESHOLD: u64 = 16_000_000_000;
/// Extra headroom (bytes) added by `qwen2_hot_text_encoder_required_vram`.
const QWEN2_HOT_TE_RESIDENCY_HEADROOM: u64 = 1_000_000_000;
68
/// Placement mode for the Qwen2 text encoder, selected through the
/// `MOLD_QWEN2_TEXT_ENCODER_MODE` environment variable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Qwen2TextEncoderMode {
    /// No explicit override; also the result for an unset or unrecognised value.
    Auto,
    /// Selected by the value `gpu`.
    Gpu,
    /// Selected by `cpu-stage` or `cpu_stage`.
    CpuStage,
    /// Selected by the value `cpu`.
    Cpu,
}

impl Qwen2TextEncoderMode {
    /// Reads `MOLD_QWEN2_TEXT_ENCODER_MODE` and maps it, ASCII-case-insensitively, to a mode.
    ///
    /// A missing or non-unicode variable behaves like the empty string, which — like any
    /// other unrecognised value — resolves to [`Qwen2TextEncoderMode::Auto`].
    fn from_env() -> Self {
        let requested = std::env::var("MOLD_QWEN2_TEXT_ENCODER_MODE")
            .unwrap_or_default()
            .to_ascii_lowercase();
        match requested.as_str() {
            "gpu" => Self::Gpu,
            "cpu-stage" | "cpu_stage" => Self::CpuStage,
            "cpu" => Self::Cpu,
            _ => Self::Auto,
        }
    }
}
92
/// Pair of placement flags for the Qwen2 text encoder.
///
/// NOTE(review): the code that builds and consumes this plan is outside this view; the field
/// meanings below are inferred from their names — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Qwen2TextEncoderPlan {
    // presumably: run the encoder on the GPU
    use_gpu: bool,
    // presumably: move the encoder outputs to the CPU after encoding
    use_cpu_staging: bool,
}
98
/// Outcome of `choose_text_encoder_source`: which text-encoder variant to use and whether
/// the automatic policy would place it on the GPU.
#[derive(Debug, Clone)]
struct ResolvedQwen2TextEncoder {
    // Left empty by `choose_text_encoder_source`; presumably filled in by the caller — confirm.
    paths: Vec<std::path::PathBuf>,
    // Left empty by `choose_text_encoder_source`; presumably filled in by the caller — confirm.
    vision_paths: Vec<std::path::PathBuf>,
    // `true` for the quantized variant tags, `false` for the "bf16" choice.
    is_gguf: bool,
    // Variant tag of the selection (e.g. "bf16", or the manifest variant's tag).
    variant_label: String,
    // Size of the selected variant in bytes.
    size_bytes: u64,
    // Result of the memory-fit check made while selecting the variant.
    auto_use_gpu: bool,
}
108
/// Usage hint accepted by `choose_text_encoder_source`, which currently ignores it
/// (the parameter is named `_usage`).
///
/// NOTE(review): variant meanings are inferred from their names — confirm against callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Qwen2TextEncoderUsage {
    Sequential,
    Resident,
}
114
/// Decision produced by `qwen2_text_encoder_post_encode_action` for what to do with the
/// text encoder once a prompt has been encoded.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Qwen2TextEncoderPostEncodeAction {
    /// Chosen when the encoder is on the GPU and the hot-residency conditions all hold.
    KeepGpu,
    /// Chosen when RAM parking is requested and the encoder is neither Metal nor quantized.
    ParkCpu,
    /// Fallback in every other case, including when the encoder is not on the GPU.
    Drop,
}
121
/// Inputs consumed by `qwen2_text_encoder_post_encode_action`.
#[derive(Debug, Clone, Copy)]
struct Qwen2TextEncoderResidencyInput {
    // When false the decision is always `Drop`.
    on_gpu: bool,
    // A quantized encoder is never parked on the CPU.
    is_quantized: bool,
    // On Metal the encoder is neither kept on the GPU nor parked.
    is_metal: bool,
    // Enables the `ParkCpu` outcome.
    keep_te_ram: bool,
    // One of the conditions required for `KeepGpu`.
    prompt_cache_miss: bool,
    // One of the conditions required for `KeepGpu`.
    transformer_resident: bool,
    // Compared against `required_vram_bytes`; `KeepGpu` needs free >= required.
    free_vram_bytes: u64,
    required_vram_bytes: u64,
}
133
/// Summary statistics of a tensor as computed by `tensor_stats`.
#[derive(Debug, Clone, Copy)]
struct QwenTensorStats {
    // Minimum over finite elements; NaN when there are none.
    min: f32,
    // Maximum over finite elements; NaN when there are none.
    max: f32,
    // Mean over finite elements; NaN when there are none.
    mean: f32,
    // Number of NaN elements.
    nan_count: u64,
    // Number of +infinity elements.
    pos_inf_count: u64,
    // Number of -infinity elements.
    neg_inf_count: u64,
    // Total number of elements, finite or not.
    total: usize,
}
144
145fn safetensors_is_fp8(path: &Path) -> bool {
148 if path.to_str().map(|s| s.contains("fp8")).unwrap_or(false) {
150 return true;
151 }
152 let Ok(tensors) = (unsafe { candle_core::safetensors::MmapedSafetensors::multi(&[path]) })
154 else {
155 return false;
156 };
157 for key in ["x_embedder.weight", "img_in.weight"] {
158 if let Ok(t) = tensors.load(key, &Device::Cpu) {
159 return t.dtype() == DType::F8E4M3;
160 }
161 }
162 false
163}
164
165fn text_encoder_is_fp8(paths: &[std::path::PathBuf]) -> bool {
169 if paths
171 .iter()
172 .any(|p| p.to_str().map(|s| s.contains("fp8")).unwrap_or(false))
173 {
174 return true;
175 }
176 let Some(first) = paths.first() else {
178 return false;
179 };
180 let Ok(tensors) = (unsafe { candle_core::safetensors::MmapedSafetensors::multi(&[first]) })
181 else {
182 return false;
183 };
184 for key in [
185 "model.embed_tokens.weight",
186 "model.layers.0.self_attn.q_proj.weight",
187 ] {
188 if let Ok(t) = tensors.load(key, &Device::Cpu) {
189 return t.dtype() == DType::F8E4M3;
190 }
191 }
192 false
193}
194
/// Bundle of loaded model components held by the engine (it is the payload type of
/// `EngineBase<LoadedQwenImage>`).
///
/// NOTE(review): field roles below are inferred from names and types; the code that
/// populates them is outside this view — confirm.
struct LoadedQwenImage {
    // Optional, so the transformer may be absent while the rest stays loaded.
    transformer: Option<QwenImageTransformer>,
    text_encoder: encoders::qwen2_text::Qwen2TextEncoder,
    vae: QwenImageVae,
    // presumably kept so a VAE can be loaded again (e.g. for a CPU fallback) — confirm
    vae_path: std::path::PathBuf,
    transformer_cfg: QwenImageConfig,
    device: Device,
    // The VAE has its own device, which may differ from `device`.
    vae_device: Device,
    dtype: DType,
}
209
/// The transformer backends the engine can dispatch to; `forward` and `forward_packed`
/// are forwarded to whichever variant is active.
#[allow(clippy::large_enum_variant)]
enum QwenImageTransformer {
    /// Non-quantized transformer model.
    BF16(QwenImageTransformer2DModel),
    /// Quantized transformer model (built from a GGUF var builder in `load_transformer`).
    Quantized(QuantizedQwenImageTransformer2DModel),
    /// Transformer from the `offload` module.
    Offloaded(super::offload::OffloadedQwenImageTransformer),
}
216
/// Prompt-cache entry: the encoder hidden states plus the number of valid tokens, from
/// which the attention mask is rebuilt on restore.
#[derive(Clone)]
struct CachedPromptConditioning {
    // Cached copy of the text-encoder hidden states.
    hidden_states: CachedTensor,
    // Number of leading positions along dim 1 that are marked valid (1) in the mask.
    valid_len: usize,
}
222
223impl CachedPromptConditioning {
224 fn from_parts(hidden_states: &Tensor, valid_len: usize) -> Result<Self> {
225 Ok(Self {
226 hidden_states: CachedTensor::from_tensor(hidden_states)?,
227 valid_len,
228 })
229 }
230
231 fn restore(&self, device: &Device, dtype: DType) -> Result<(Tensor, Tensor)> {
232 let hidden_states = self.hidden_states.restore(device, dtype)?;
233 let mut mask = vec![0u8; hidden_states.dim(1)?];
234 for value in &mut mask[..self.valid_len] {
235 *value = 1;
236 }
237 let attention_mask = Tensor::from_vec(mask, (1, hidden_states.dim(1)?), device)?;
238 Ok((hidden_states, attention_mask))
239 }
240}
241
242fn pad_text_conditioning(
243 hidden_states: &Tensor,
244 attention_mask: &Tensor,
245 target_len: usize,
246) -> Result<(Tensor, Tensor)> {
247 let seq_len = hidden_states.dim(1)?;
248 if seq_len == target_len {
249 return Ok((hidden_states.clone(), attention_mask.clone()));
250 }
251 if seq_len > target_len {
252 bail!("cannot shrink text conditioning from {seq_len} to {target_len}");
253 }
254
255 let hidden_dim = hidden_states.dim(2)?;
256 let pad_len = target_len - seq_len;
257 let pad_hs = Tensor::zeros(
258 (hidden_states.dim(0)?, pad_len, hidden_dim),
259 hidden_states.dtype(),
260 hidden_states.device(),
261 )?;
262 let pad_mask = Tensor::zeros(
263 (attention_mask.dim(0)?, pad_len),
264 attention_mask.dtype(),
265 attention_mask.device(),
266 )?;
267
268 Ok((
269 Tensor::cat(&[hidden_states, &pad_hs], 1)?,
270 Tensor::cat(&[attention_mask, &pad_mask], 1)?,
271 ))
272}
273
274fn align_cfg_conditioning(
275 cond_hs: &Tensor,
276 cond_mask: &Tensor,
277 uncond_hs: &Tensor,
278 uncond_mask: &Tensor,
279) -> Result<((Tensor, Tensor), (Tensor, Tensor))> {
280 let target_len = cond_hs.dim(1)?.max(uncond_hs.dim(1)?);
281 let cond = pad_text_conditioning(cond_hs, cond_mask, target_len)?;
282 let uncond = pad_text_conditioning(uncond_hs, uncond_mask, target_len)?;
283 Ok((cond, uncond))
284}
285
286impl QwenImageTransformer {
287 fn supports_cfg_batching(&self) -> bool {
288 match self {
289 Self::Quantized(model) => model.supports_cfg_batching(),
290 _ => true,
291 }
292 }
293
294 fn forward(
295 &self,
296 latents: &Tensor,
297 t: &Tensor,
298 encoder_hidden_states: &Tensor,
299 encoder_attention_mask: &Tensor,
300 ) -> Result<Tensor> {
301 match self {
302 Self::BF16(model) => {
303 Ok(model.forward(latents, t, encoder_hidden_states, encoder_attention_mask)?)
304 }
305 Self::Quantized(model) => {
306 Ok(model.forward(latents, t, encoder_hidden_states, encoder_attention_mask)?)
307 }
308 Self::Offloaded(model) => {
309 model.forward(latents, t, encoder_hidden_states, encoder_attention_mask)
310 }
311 }
312 }
313
314 fn forward_packed(
315 &self,
316 packed_latents: &Tensor,
317 t: &Tensor,
318 encoder_hidden_states: &Tensor,
319 encoder_attention_mask: &Tensor,
320 img_shapes: &[(usize, usize, usize)],
321 ) -> Result<Tensor> {
322 match self {
323 Self::BF16(model) => Ok(model.forward_packed(
324 packed_latents,
325 t,
326 encoder_hidden_states,
327 encoder_attention_mask,
328 img_shapes,
329 )?),
330 Self::Quantized(model) => Ok(model.forward_packed(
331 packed_latents,
332 t,
333 encoder_hidden_states,
334 encoder_attention_mask,
335 img_shapes,
336 )?),
337 Self::Offloaded(model) => model.forward_packed(
338 packed_latents,
339 t,
340 encoder_hidden_states,
341 encoder_attention_mask,
342 img_shapes,
343 ),
344 }
345 }
346}
347
/// Inference engine for Qwen-Image models.
///
/// NOTE(review): several field roles are inferred from names and types only; the code that
/// uses them is partly outside this view — confirm.
pub struct QwenImageEngine {
    // Shared engine state (model name, paths, progress reporter, GPU ordinal, ...).
    base: EngineBase<LoadedQwenImage>,
    // LRU cache of prompt conditioning, keyed by `prompt_text_key(prompt)`.
    prompt_cache: Mutex<LruCache<String, CachedPromptConditioning>>,
    // When set, `load_transformer` takes the offload / low-memory paths.
    offload: bool,
    // presumably a device placement to apply on the next load — confirm
    pending_placement: Option<mold_core::types::DevicePlacement>,
    // LoRA stack read by `load_transformer` when building the transformer.
    pending_loras: Vec<mold_core::LoraWeight>,
    // Initialised empty in `new`; marked dead code, so not read yet.
    #[allow(dead_code)]
    active_lora_fingerprint: Vec<QwenImageLoraFingerprint>,
    // Optional shared pool; when present it is used to load the tokenizer.
    shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
}
371
/// Comparable identity of a single LoRA entry: a hash of its path plus the exact bit
/// pattern of its scale (see `from_lora`).
#[derive(Clone, PartialEq, Eq, Debug)]
#[allow(dead_code)]
struct QwenImageLoraFingerprint {
    // Result of `super::lora::lora_path_hash` for the LoRA path.
    path_hash: u64,
    // `to_bits()` of the LoRA scale, so equality is exact rather than float-approximate.
    scale_bits: u64,
}
379
380impl QwenImageLoraFingerprint {
381 #[allow(dead_code)]
382 fn from_lora(lora: &mold_core::LoraWeight) -> Self {
383 Self {
384 path_hash: super::lora::lora_path_hash(&lora.path),
385 scale_bits: lora.scale.to_bits(),
386 }
387 }
388}
389
390#[allow(dead_code)]
391fn fingerprint_stack(loras: &[mold_core::LoraWeight]) -> Vec<QwenImageLoraFingerprint> {
392 loras
393 .iter()
394 .map(QwenImageLoraFingerprint::from_lora)
395 .collect()
396}
397
398fn effective_loras(req: &mold_core::GenerateRequest) -> Vec<mold_core::LoraWeight> {
402 const ZERO_SCALE_EPS: f64 = 1e-8;
405
406 let raw: Vec<mold_core::LoraWeight> = if let Some(plural) = &req.loras {
407 if !plural.is_empty() {
408 plural.clone()
409 } else {
410 req.lora.iter().cloned().collect()
411 }
412 } else {
413 req.lora.iter().cloned().collect()
414 };
415
416 raw.into_iter()
417 .filter(|w| {
418 let keep = w.scale.abs() > ZERO_SCALE_EPS;
419 if !keep {
420 tracing::debug!(
421 path = w.path.as_str(),
422 scale = w.scale,
423 "dropping zero-scale LoRA from effective Qwen-Image stack"
424 );
425 }
426 keep
427 })
428 .collect()
429}
430
431impl QwenImageEngine {
432 fn is_edit_family(&self) -> bool {
433 self.base.model_name.starts_with("qwen-image-edit")
434 }
435
436 fn should_preload_text_encoder(&self) -> bool {
437 !self.is_edit_family()
438 }
439
440 fn text_encoder_load_dtype(use_gpu: bool, gpu_dtype: DType) -> DType {
441 if use_gpu {
442 gpu_dtype
443 } else {
444 DType::F32
448 }
449 }
450
451 fn transformer_config(&self) -> QwenImageConfig {
452 if self.is_edit_family() {
453 QwenImageConfig::qwen_image_edit_2511()
454 } else {
455 QwenImageConfig::qwen_image_2512()
456 }
457 }
458
459 fn qwen_image_edit_prompt(prompt: &str, image_count: usize) -> String {
460 let picture_prefix = (0..image_count)
461 .map(|idx| {
462 format!(
463 "Picture {}: <|vision_start|><|image_pad|><|vision_end|>",
464 idx + 1
465 )
466 })
467 .collect::<String>();
468 format!(
469 "<|im_start|>system\n{QWEN_IMAGE_EDIT_SYSTEM_PROMPT}<|im_end|>\n<|im_start|>user\n{picture_prefix}{prompt}<|im_end|>\n<|im_start|>assistant\n"
470 )
471 }
472
473 fn qwen_image_edit_image_dims(image: &[u8], target_area: u32) -> Result<(u32, u32)> {
474 let img = image::load_from_memory(image)?;
475 Ok(fit_to_target_area(
476 img.width().max(1),
477 img.height().max(1),
478 target_area,
479 16,
480 ))
481 }
482
483 fn pack_latents_4d(latents: &Tensor) -> Result<Tensor> {
484 let (batch, channels, height, width) = latents.dims4()?;
485 let height_blocks = height / 2;
486 let width_blocks = width / 2;
487 latents
488 .reshape((batch, channels, height_blocks, 2, width_blocks, 2))?
489 .permute((0, 2, 4, 1, 3, 5))?
490 .reshape((batch, height_blocks * width_blocks, channels * 4))
491 .map_err(Into::into)
492 }
493
494 fn unpack_latents_packed(latents: &Tensor, latent_h: usize, latent_w: usize) -> Result<Tensor> {
495 let batch = latents.dim(0)?;
496 latents
497 .reshape((batch, latent_h / 2, latent_w / 2, 16, 2, 2))?
498 .permute((0, 3, 1, 4, 2, 5))?
499 .reshape((batch, 16, latent_h, latent_w))
500 .map_err(Into::into)
501 }
502
    /// Normalization range used when decoding source images for the VAE encoder
    /// (see `encode_vae_with_fallback`): always `MinusOneToOne`.
    fn img2img_source_normalize_range() -> img_utils::NormalizeRange {
        img_utils::NormalizeRange::MinusOneToOne
    }
506
507 fn is_oom_error(err: &impl std::fmt::Display) -> bool {
508 let msg = err.to_string();
511 msg.contains("OUT_OF_MEMORY")
512 || msg.contains("out of memory")
513 || msg.contains("cudaErrorMemoryAllocation")
514 }
515
516 fn with_cuda_oom_cpu_fallback<T, FPrimary, FFallback, FOom>(
517 primary: FPrimary,
518 fallback: FFallback,
519 is_cuda: bool,
520 sync_device: &Device,
521 progress: &ProgressReporter,
522 oom_message: &str,
523 is_oom: FOom,
524 ) -> Result<T>
525 where
526 FPrimary: FnOnce() -> Result<T>,
527 FFallback: FnOnce() -> Result<T>,
528 FOom: Fn(&anyhow::Error) -> bool,
529 {
530 match primary() {
531 Ok(value) => Ok(value),
532 Err(err) if is_cuda && is_oom(&err) => {
533 progress.info(oom_message);
534 sync_device.synchronize()?;
535 fallback()
536 }
537 Err(err) => Err(err),
538 }
539 }
540
541 #[allow(clippy::too_many_arguments)]
542 fn with_cuda_tiled_then_cpu_fallback<T, FPrimary, FTiled, FCpu, FOom>(
543 primary: FPrimary,
544 tiled: FTiled,
545 cpu_fallback: FCpu,
546 is_cuda: bool,
547 prefer_tiled: bool,
548 sync_device: &Device,
549 progress: &ProgressReporter,
550 tiled_message: &str,
551 cpu_message: &str,
552 is_oom: FOom,
553 ) -> Result<T>
554 where
555 FPrimary: FnOnce() -> Result<T>,
556 FTiled: FnOnce() -> Result<T>,
557 FCpu: FnOnce() -> Result<T>,
558 FOom: Fn(&anyhow::Error) -> bool,
559 {
560 if is_cuda && prefer_tiled {
561 progress.info("Selecting tiled GPU VAE decode proactively");
562 match tiled() {
563 Ok(value) => return Ok(value),
564 Err(tile_err) if is_oom(&tile_err) => {
565 progress.info(cpu_message);
566 sync_device.synchronize()?;
567 return cpu_fallback();
568 }
569 Err(tile_err) => return Err(tile_err),
570 }
571 }
572
573 match primary() {
574 Ok(value) => Ok(value),
575 Err(err) if is_cuda && is_oom(&err) => {
576 progress.info(tiled_message);
577 sync_device.synchronize()?;
578 match tiled() {
579 Ok(value) => Ok(value),
580 Err(tile_err) if is_oom(&tile_err) => {
581 progress.info(cpu_message);
582 sync_device.synchronize()?;
583 cpu_fallback()
584 }
585 Err(tile_err) => Err(tile_err),
586 }
587 }
588 Err(err) => Err(err),
589 }
590 }
591
592 fn qwen_vae_decode_workspace_bytes(width: u32, height: u32) -> u64 {
593 let pixels = width as u64 * height as u64;
594 pixels.saturating_mul(4).saturating_mul(1024)
599 }
600
601 fn should_proactively_tile_vae_decode(
602 width: u32,
603 height: u32,
604 vae_is_cuda: bool,
605 free_vram_bytes: u64,
606 ) -> bool {
607 if !vae_is_cuda || free_vram_bytes == 0 {
608 return false;
609 }
610 let native_pixels = (QWEN_NATIVE_WIDTH * QWEN_NATIVE_HEIGHT) as u64;
611 let pixels = width as u64 * height as u64;
612 if pixels < native_pixels.saturating_mul(3) / 4 {
613 return false;
614 }
615 let required = VAE_DECODE_VRAM_THRESHOLD
616 .saturating_add(Self::qwen_vae_decode_workspace_bytes(width, height));
617 free_vram_bytes < required
618 }
619
620 fn qwen2_text_encoder_post_encode_action(
621 input: Qwen2TextEncoderResidencyInput,
622 ) -> Qwen2TextEncoderPostEncodeAction {
623 if !input.on_gpu {
624 return Qwen2TextEncoderPostEncodeAction::Drop;
625 }
626 if input.prompt_cache_miss
627 && input.transformer_resident
628 && !input.is_metal
629 && input.free_vram_bytes >= input.required_vram_bytes
630 {
631 return Qwen2TextEncoderPostEncodeAction::KeepGpu;
632 }
633 if input.keep_te_ram && !input.is_metal && !input.is_quantized {
634 return Qwen2TextEncoderPostEncodeAction::ParkCpu;
635 }
636 Qwen2TextEncoderPostEncodeAction::Drop
637 }
638
639 fn qwen2_hot_text_encoder_required_vram(
640 width: u32,
641 height: u32,
642 cfg_batch: u32,
643 dtype: DType,
644 ) -> u64 {
645 crate::device::activation_bytes(
646 width,
647 height,
648 cfg_batch,
649 crate::device::dtype_bytes(dtype),
650 crate::device::ActivationFamily::QwenImageDit,
651 )
652 .saturating_add(VAE_DECODE_VRAM_THRESHOLD)
653 .saturating_add(Self::qwen_vae_decode_workspace_bytes(width, height))
654 .saturating_add(QWEN2_HOT_TE_RESIDENCY_HEADROOM)
655 }
656
    /// Decodes latents tile by tile, retrying with progressively smaller tiles.
    ///
    /// Each size in `QWEN_VAE_TILE_SIZES` is tried in order. An OOM error on a CUDA VAE
    /// device moves on to the next size; any other error is returned immediately. If every
    /// size runs out of memory the function fails.
    fn decode_vae_tiled(
        latents: &Tensor,
        vae: &QwenImageVae,
        vae_device: &Device,
        progress: &ProgressReporter,
    ) -> Result<Tensor> {
        for tile_size in QWEN_VAE_TILE_SIZES {
            // Overlap is a quarter of the tile, but never less than 4.
            let overlap = (tile_size / 4).max(4);
            progress.info(&format!(
                "Retrying VAE decode with tiled GPU decode (tile {} overlap {})",
                tile_size, overlap
            ));
            let config = TilingConfig {
                tile_size,
                overlap,
                min_tile_size: 16,
            };
            // Per-tile work: move the tile to the VAE device as F32, then decode it.
            let forward = |tile: &Tensor| {
                let tile = tile.to_device(vae_device)?.to_dtype(DType::F32)?;
                vae.decode(&tile).map_err(Into::into)
            };
            // NOTE(review): `8` is presumably the latent-to-pixel scale factor and
            // `Device::Cpu` the device the stitched output is assembled on — confirm
            // against `upscale_with_tiling`.
            match upscale_with_tiling(latents, &forward, 8, &config, &Device::Cpu, progress) {
                Ok(image) => return Ok(image),
                Err(e) if vae_device.is_cuda() && Self::is_oom_error(&e) => {
                    // Best-effort sync before the next attempt; a failure here is only logged.
                    if let Err(sync_err) = vae_device.synchronize() {
                        tracing::warn!(
                            "failed to synchronize CUDA device after tiled VAE OOM: {sync_err}"
                        );
                    }
                }
                Err(e) => return Err(e),
            }
        }

        bail!("tiled VAE decode still ran out of memory")
    }
697
698 fn decode_vae_with_fallback<F>(
699 latents: &Tensor,
700 vae: &QwenImageVae,
701 vae_device: &Device,
702 sync_device: &Device,
703 progress: &ProgressReporter,
704 prefer_tiled: bool,
705 load_cpu_vae: F,
706 ) -> Result<Tensor>
707 where
708 F: FnOnce() -> Result<QwenImageVae>,
709 {
710 let decode_latents = latents.to_device(vae_device)?.to_dtype(DType::F32)?;
711 Self::debug_tensor_stats("latents_pre_vae", &decode_latents);
712 Self::with_cuda_tiled_then_cpu_fallback(
713 || vae.decode(&decode_latents).map_err(Into::into),
714 || Self::decode_vae_tiled(latents, vae, vae_device, progress),
715 || {
716 let cpu_vae = load_cpu_vae()?;
717 let cpu_latents = latents.to_device(&Device::Cpu)?.to_dtype(DType::F32)?;
718 cpu_vae.decode(&cpu_latents).map_err(Into::into)
719 },
720 vae_device.is_cuda(),
721 prefer_tiled,
722 sync_device,
723 progress,
724 "VAE decode OOM on GPU — retrying with tiled GPU decode",
725 "Tiled GPU VAE decode OOM — retrying on CPU",
726 Self::is_oom_error,
727 )
728 }
729
730 #[allow(clippy::too_many_arguments)]
732 fn encode_vae_with_fallback(
733 source_bytes: &[u8],
734 width: u32,
735 height: u32,
736 vae: &QwenImageVae,
737 vae_device: &Device,
738 sync_device: &Device,
739 progress: &ProgressReporter,
740 load_cpu_vae: impl FnOnce() -> Result<QwenImageVae>,
741 ) -> Result<Tensor> {
742 progress.stage_start("Encoding source image (VAE)");
743 let encode_start = Instant::now();
744
745 let source_tensor = img_utils::decode_source_image(
747 source_bytes,
748 width,
749 height,
750 Self::img2img_source_normalize_range(),
751 vae_device,
752 DType::F32,
753 )?;
754
755 let result = Self::with_cuda_oom_cpu_fallback(
756 || vae.encode(&source_tensor).map_err(Into::into),
757 || {
758 let cpu_vae = load_cpu_vae()?;
759 let cpu_source = img_utils::decode_source_image(
760 source_bytes,
761 width,
762 height,
763 Self::img2img_source_normalize_range(),
764 &Device::Cpu,
765 DType::F32,
766 )?;
767 cpu_vae.encode(&cpu_source).map_err(Into::into)
768 },
769 vae_device.is_cuda(),
770 sync_device,
771 progress,
772 "VAE encode OOM on GPU — retrying on CPU",
773 Self::is_oom_error,
774 );
775
776 progress.stage_done("Encoding source image (VAE)", encode_start.elapsed());
777 result
778 }
779
780 fn choose_text_encoder_source(
781 preference: Option<&str>,
782 is_cuda: bool,
783 is_metal: bool,
784 free_vram: u64,
785 bf16_size_bytes: u64,
786 _usage: Qwen2TextEncoderUsage,
787 ) -> Result<ResolvedQwen2TextEncoder> {
788 match preference {
789 Some(tag) if tag != "auto" && tag != "bf16" => {
790 let variant = mold_core::manifest::find_qwen2_vl_variant(tag).ok_or_else(|| {
791 anyhow::anyhow!(
792 "unknown Qwen2.5-VL variant '{}'. Valid: bf16, auto, q8, q6, q5, q4, q3, q2",
793 tag
794 )
795 })?;
796 Ok(ResolvedQwen2TextEncoder {
797 paths: vec![],
798 vision_paths: vec![],
799 is_gguf: true,
800 variant_label: variant.tag.to_string(),
801 size_bytes: variant.size_bytes,
802 auto_use_gpu: should_use_gpu(
803 is_cuda,
804 is_metal,
805 free_vram,
806 qwen2_vram_threshold(variant.size_bytes),
807 ),
808 })
809 }
810 Some("bf16") => Ok(ResolvedQwen2TextEncoder {
811 paths: vec![],
812 vision_paths: vec![],
813 is_gguf: false,
814 variant_label: "bf16".to_string(),
815 size_bytes: bf16_size_bytes,
816 auto_use_gpu: should_use_gpu(
817 is_cuda,
818 is_metal,
819 free_vram,
820 QWEN2_FP16_VRAM_THRESHOLD,
821 ),
822 }),
823 _ if is_metal => {
824 for tag in ["q6", "q4"] {
825 let variant = mold_core::manifest::find_qwen2_vl_variant(tag)
826 .expect("known Metal auto qwen2 variant missing");
827 if fits_in_memory(
828 is_cuda,
829 is_metal,
830 free_vram,
831 qwen2_vram_threshold(variant.size_bytes),
832 ) {
833 return Ok(ResolvedQwen2TextEncoder {
834 paths: vec![],
835 vision_paths: vec![],
836 is_gguf: true,
837 variant_label: variant.tag.to_string(),
838 size_bytes: variant.size_bytes,
839 auto_use_gpu: true,
840 });
841 }
842 }
843 let fallback = mold_core::manifest::find_qwen2_vl_variant("q4")
844 .expect("known Metal fallback qwen2 variant missing");
845 Ok(ResolvedQwen2TextEncoder {
846 paths: vec![],
847 vision_paths: vec![],
848 is_gguf: true,
849 variant_label: fallback.tag.to_string(),
850 size_bytes: fallback.size_bytes,
851 auto_use_gpu: true,
852 })
853 }
854 _ => {
855 let bf16_on_gpu =
856 should_use_gpu(is_cuda, is_metal, free_vram, QWEN2_FP16_VRAM_THRESHOLD);
857 if bf16_on_gpu {
858 return Ok(ResolvedQwen2TextEncoder {
859 paths: vec![],
860 vision_paths: vec![],
861 is_gguf: false,
862 variant_label: "bf16".to_string(),
863 size_bytes: bf16_size_bytes,
864 auto_use_gpu: true,
865 });
866 }
867
868 if is_cuda {
869 let fallback_tag = "q4";
870 let fallback = mold_core::manifest::find_qwen2_vl_variant(fallback_tag)
871 .expect("known CUDA fallback qwen2 variant missing");
872 return Ok(ResolvedQwen2TextEncoder {
873 paths: vec![],
874 vision_paths: vec![],
875 is_gguf: true,
876 variant_label: fallback.tag.to_string(),
877 size_bytes: fallback.size_bytes,
878 auto_use_gpu: fits_in_memory(
879 is_cuda,
880 is_metal,
881 free_vram,
882 qwen2_vram_threshold(fallback.size_bytes),
883 ),
884 });
885 }
886
887 Ok(ResolvedQwen2TextEncoder {
888 paths: vec![],
889 vision_paths: vec![],
890 is_gguf: false,
891 variant_label: "bf16".to_string(),
892 size_bytes: bf16_size_bytes,
893 auto_use_gpu: false,
894 })
895 }
896 }
897 }
898
899 fn tensor_stats(tensor: &Tensor) -> Result<QwenTensorStats> {
900 let t = tensor.to_dtype(DType::F32)?;
901 let values = t.flatten_all()?.to_vec1::<f32>()?;
902 let mut min = f32::INFINITY;
903 let mut max = f32::NEG_INFINITY;
904 let mut sum = 0.0f64;
905 let mut finite_count = 0usize;
906 let mut nan_count = 0u64;
907 let mut pos_inf_count = 0u64;
908 let mut neg_inf_count = 0u64;
909 for value in &values {
910 if value.is_nan() {
911 nan_count += 1;
912 } else if *value == f32::INFINITY {
913 pos_inf_count += 1;
914 } else if *value == f32::NEG_INFINITY {
915 neg_inf_count += 1;
916 } else {
917 min = min.min(*value);
918 max = max.max(*value);
919 sum += *value as f64;
920 finite_count += 1;
921 }
922 }
923 let mean = if finite_count == 0 {
924 f32::NAN
925 } else {
926 (sum / finite_count as f64) as f32
927 };
928 if finite_count == 0 {
929 min = f32::NAN;
930 max = f32::NAN;
931 }
932 Ok(QwenTensorStats {
933 min,
934 max,
935 mean,
936 nan_count,
937 pos_inf_count,
938 neg_inf_count,
939 total: values.len(),
940 })
941 }
942
943 fn format_tensor_stats(name: &str, stats: QwenTensorStats) -> String {
944 format!(
945 "[qwen-debug] {name}: min={:.4} max={:.4} mean={:.4} NaN={}/{} ({:.1}%) +Inf={} -Inf={}",
946 stats.min,
947 stats.max,
948 stats.mean,
949 stats.nan_count,
950 stats.total,
951 stats.nan_count as f64 / stats.total.max(1) as f64 * 100.0,
952 stats.pos_inf_count,
953 stats.neg_inf_count
954 )
955 }
956
957 fn near_black_image_stats(stats: QwenTensorStats) -> bool {
958 if stats.nan_count > 0
959 || stats.pos_inf_count > 0
960 || stats.neg_inf_count > 0
961 || !stats.min.is_finite()
962 || !stats.max.is_finite()
963 || !stats.mean.is_finite()
964 {
965 return false;
966 }
967 let scale = if stats.max <= 1.0 { 1.0 } else { 255.0 };
968 stats.max <= 0.02 * scale && stats.mean <= 0.01 * scale
969 }
970
971 fn validate_qwen_tensor_boundary(name: &str, tensor: &Tensor) -> Result<QwenTensorStats> {
972 let stats = Self::tensor_stats(tensor)?;
973 if stats.nan_count > 0
974 || stats.pos_inf_count > 0
975 || stats.neg_inf_count > 0
976 || !stats.min.is_finite()
977 || !stats.max.is_finite()
978 || !stats.mean.is_finite()
979 {
980 bail!(
981 "Qwen diagnostic boundary '{name}' contains non-finite values: {}",
982 Self::format_tensor_stats(name, stats)
983 );
984 }
985 Ok(stats)
986 }
987
988 fn debug_tensor_stats(name: &str, tensor: &Tensor) {
989 if std::env::var_os("MOLD_QWEN_DEBUG").is_none() {
990 return;
991 }
992 match Self::tensor_stats(tensor) {
993 Ok(stats) => eprintln!("{}", Self::format_tensor_stats(name, stats)),
994 Err(err) => eprintln!("[qwen-debug] {name}: <failed: {err}>"),
995 }
996 }
997
998 pub fn new(
999 model_name: String,
1000 paths: ModelPaths,
1001 load_strategy: LoadStrategy,
1002 gpu_ordinal: usize,
1003 offload: bool,
1004 shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
1005 ) -> Self {
1006 Self {
1007 base: EngineBase::new(model_name, paths, load_strategy, gpu_ordinal),
1008 prompt_cache: Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY)),
1009 offload,
1010 pending_placement: None,
1011 pending_loras: Vec::new(),
1012 active_lora_fingerprint: Vec::new(),
1013 shared_pool,
1014 }
1015 }
1016
1017 fn load_text_tokenizer(&self, tokenizer_path: &Path) -> Result<Arc<Tokenizer>> {
1018 if let Some(shared_pool) = &self.shared_pool {
1019 return shared_pool.lock().unwrap().load_tokenizer(tokenizer_path);
1020 }
1021 Tokenizer::from_file(tokenizer_path)
1022 .map(Arc::new)
1023 .map_err(|e| anyhow::anyhow!("failed to load Qwen2.5 tokenizer: {e}"))
1024 }
1025
1026 fn encode_prompt_cached(
1027 progress: &ProgressReporter,
1028 prompt_cache: &Mutex<LruCache<String, CachedPromptConditioning>>,
1029 text_encoder: &mut encoders::qwen2_text::Qwen2TextEncoder,
1030 prompt: &str,
1031 device: &Device,
1032 dtype: DType,
1033 ) -> Result<(Tensor, Tensor)> {
1034 let cache_key = prompt_text_key(prompt);
1035 if let Some(cached) = prompt_cache
1036 .lock()
1037 .expect("cache poisoned")
1038 .get_cloned(&cache_key)
1039 {
1040 progress.cache_hit("prompt conditioning");
1041 return cached.restore(device, dtype);
1042 }
1043
1044 progress.stage_start("Encoding prompt (Qwen2.5)");
1045 let encode_start = Instant::now();
1046 let (hidden_states, _attention_mask, valid_len) =
1047 text_encoder.encode(prompt, device, dtype)?;
1048 progress.stage_done("Encoding prompt (Qwen2.5)", encode_start.elapsed());
1049
1050 prompt_cache.lock().expect("cache poisoned").insert(
1051 cache_key,
1052 CachedPromptConditioning::from_parts(&hidden_states, valid_len)?,
1053 );
1054
1055 let mut mask = vec![0u8; hidden_states.dim(1)?];
1056 for value in &mut mask[..valid_len] {
1057 *value = 1;
1058 }
1059 let attention_mask = Tensor::from_vec(mask, (1, hidden_states.dim(1)?), device)?;
1060 Ok((hidden_states, attention_mask))
1061 }
1062
1063 fn spill_conditioning_to_cpu(
1064 hidden_states: Tensor,
1065 attention_mask: Tensor,
1066 ) -> Result<(Tensor, Tensor)> {
1067 Ok((
1068 hidden_states
1069 .to_device(&Device::Cpu)?
1070 .to_dtype(DType::F32)?,
1071 attention_mask.to_device(&Device::Cpu)?,
1072 ))
1073 }
1074
1075 fn maybe_spill_conditioning(
1076 use_cpu_staging: bool,
1077 hidden_states: Tensor,
1078 attention_mask: Tensor,
1079 ) -> Result<(Tensor, Tensor)> {
1080 if use_cpu_staging {
1081 Self::spill_conditioning_to_cpu(hidden_states, attention_mask)
1082 } else {
1083 Ok((hidden_states, attention_mask))
1084 }
1085 }
1086
1087 fn transformer_paths(&self) -> Vec<std::path::PathBuf> {
1089 if !self.base.paths.transformer_shards.is_empty() {
1090 self.base.paths.transformer_shards.clone()
1091 } else {
1092 vec![self.base.paths.transformer.clone()]
1093 }
1094 }
1095
1096 fn detect_is_quantized(&self) -> bool {
1097 self.base
1098 .paths
1099 .transformer
1100 .extension()
1101 .and_then(|e| e.to_str())
1102 .map(|e| e.eq_ignore_ascii_case("gguf"))
1103 .unwrap_or(false)
1104 }
1105
1106 fn validate_paths(&self) -> Result<std::path::PathBuf> {
1108 let text_tokenizer_path =
1109 self.base.paths.text_tokenizer.as_ref().ok_or_else(|| {
1110 anyhow::anyhow!("text tokenizer path required for Qwen-Image models")
1111 })?;
1112 if !text_tokenizer_path.exists() {
1113 bail!(
1114 "text tokenizer file not found: {}",
1115 text_tokenizer_path.display()
1116 );
1117 }
1118
1119 let xformer_paths = self.transformer_paths();
1120 for path in &xformer_paths {
1121 if !path.exists() {
1122 bail!("transformer file not found: {}", path.display());
1123 }
1124 }
1125 if !self.base.paths.vae.exists() {
1126 bail!("VAE file not found: {}", self.base.paths.vae.display());
1127 }
1128
1129 Ok(text_tokenizer_path.clone())
1130 }
1131
1132 fn quantized_cuda_cfg_headroom(width: usize, height: usize) -> u64 {
1133 let native_pixels = (QWEN_NATIVE_WIDTH * QWEN_NATIVE_HEIGHT) as f64;
1134 let pixels = (width.max(1) * height.max(1)) as f64;
1135 let scaled =
1136 (QWEN_GGUF_NATIVE_CFG_HEADROOM as f64 * (pixels / native_pixels)).round() as u64;
1137 scaled.max(QWEN_GGUF_MIN_CFG_HEADROOM)
1138 }
1139
1140 fn should_split_cfg_quantized_cuda(
1141 transformer_size: u64,
1142 free_vram: u64,
1143 width: usize,
1144 height: usize,
1145 ) -> bool {
1146 if free_vram == 0 {
1147 return true;
1150 }
1151 let estimated_peak =
1152 transformer_size.saturating_add(Self::quantized_cuda_cfg_headroom(width, height));
1153 estimated_peak > free_vram
1154 }
1155
/// Builds the Qwen-Image transformer for `device`, picking one of three
/// variants: quantized (GGUF), block-offloaded safetensors, or fully resident
/// safetensors.
///
/// * `dtype` – weight dtype used for the non-FP8 safetensors loads.
/// * `cfg` – transformer configuration handed to the model constructors.
/// * `width` / `height` – target resolution; only used to size the memory
///   estimates behind the split-CFG and offload decisions.
///
/// # Errors
/// Propagates adapter/weight loading and model construction failures, and
/// bails when LoRAs are requested while the block-offload path is selected.
fn load_transformer(
    &self,
    device: &Device,
    dtype: DType,
    cfg: &QwenImageConfig,
    width: usize,
    height: usize,
) -> Result<QwenImageTransformer> {
    let active_loras = &self.pending_loras;
    let has_lora = !active_loras.is_empty();
    if self.detect_is_quantized() {
        // Both probes degrade to 0 on failure; a zero `free` makes
        // `should_split_cfg_quantized_cuda` answer `true`.
        let transformer_size = std::fs::metadata(&self.base.paths.transformer)
            .map(|m| m.len())
            .unwrap_or(0);
        let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
        // Split-CFG only applies on CUDA: either forced by the offload flag or
        // chosen because the estimated peak does not fit in free VRAM.
        let split_cfg_for_memory = device.is_cuda()
            && (self.offload
                || Self::should_split_cfg_quantized_cuda(
                    transformer_size,
                    free,
                    width,
                    height,
                ));
        if self.offload && device.is_cuda() {
            self.base.progress.info(
                "Quantized Qwen CUDA offload requested — using low-memory split-CFG mode until GGUF block offload lands",
            );
        } else if split_cfg_for_memory {
            // Recomputed only for the log line; same formula as the decision above.
            let estimated_peak = transformer_size
                .saturating_add(Self::quantized_cuda_cfg_headroom(width, height));
            self.base.progress.info(&format!(
                "Using low-memory quantized Qwen CUDA path (est. peak {}, {} free at {}x{})",
                fmt_gb(estimated_peak),
                fmt_gb(free),
                width,
                height,
            ));
        }
        let vb = if has_lora {
            // Pair every loaded adapter with its request entry (scale + path hash).
            let adapters = super::lora::load_lora_adapters(active_loras, &self.base.progress)?;
            let specs: Vec<super::lora::QwenImageLoraSpec<'_>> = adapters
                .iter()
                .zip(active_loras.iter())
                .map(|(adapter, w)| super::lora::QwenImageLoraSpec {
                    adapter: adapter.as_ref(),
                    scale: w.scale,
                    path_hash: super::lora::lora_path_hash(&w.path),
                })
                .collect();
            super::lora::gguf_lora_var_builder(
                &self.base.paths.transformer,
                &specs,
                device,
                &self.base.progress,
                None,
            )?
        } else {
            quantized_var_builder::VarBuilder::from_gguf(&self.base.paths.transformer, device)?
        };
        // NOTE(review): the last constructor argument is the negated split-CFG
        // decision — presumably "allow batched CFG"; confirm against
        // `QuantizedQwenImageTransformer2DModel::new`.
        Ok(QwenImageTransformer::Quantized(
            QuantizedQwenImageTransformer2DModel::new(cfg, vb, device, !split_cfg_for_memory)?,
        ))
    } else {
        let xformer_paths = self.transformer_paths();
        // FP8 detection looks at the first shard only; no shards means "not FP8".
        let is_fp8 = xformer_paths
            .first()
            .map(|p| safetensors_is_fp8(p))
            .unwrap_or(false);

        // Sum of on-disk shard sizes; shards whose metadata cannot be read are skipped.
        let mem_size: u64 = xformer_paths
            .iter()
            .filter_map(|p| std::fs::metadata(p).ok())
            .map(|m| m.len())
            .sum();
        let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
        // NOTE(review): the literal `2` looks like a CFG batch of two (the
        // sequential path passes 2 only when guidance > 1.0) — confirm against
        // `activation_bytes`.
        let activation_budget = crate::device::activation_bytes(
            width as u32,
            height as u32,
            2,
            crate::device::dtype_bytes(dtype),
            crate::device::ActivationFamily::QwenImageDit,
        );
        // Offload when explicitly requested or when the device helper says so.
        let use_offload =
            self.offload || crate::device::should_offload(mem_size, free, activation_budget);

        if is_fp8 {
            self.base
                .progress
                .info("Detected FP8 safetensors — loading with scale dequantization");
        }

        if use_offload {
            if has_lora {
                bail!(
                    "Qwen-Image LoRA support is not yet wired through the block-offload \
                     transformer path. Disable offload (drop --offload / unset MOLD_OFFLOAD), \
                     or pick a checkpoint that fits without offload, to use LoRAs."
                );
            }
            // The offloaded model takes two builders over the same shards: one
            // on the target device and one on the CPU.
            let (gpu_vb, cpu_vb) = if is_fp8 {
                let gpu = crate::weight_loader::load_fp8_safetensors(
                    &xformer_paths,
                    device,
                    "Qwen-Image transformer (offload, GPU)",
                    &self.base.progress,
                )?;
                let cpu = crate::weight_loader::load_fp8_safetensors(
                    &xformer_paths,
                    &Device::Cpu,
                    "Qwen-Image transformer (offload, CPU)",
                    &self.base.progress,
                )?;
                (gpu, cpu)
            } else {
                let gpu = crate::weight_loader::load_safetensors_with_progress(
                    &xformer_paths,
                    dtype,
                    device,
                    "Qwen-Image transformer (offload, GPU)",
                    &self.base.progress,
                )?;
                // SAFETY: the unsafety comes from memory-mapping the shard files;
                // this assumes they are not modified or truncated while the
                // mapping is alive — TODO confirm nothing rewrites the checkpoint
                // during a run.
                // NOTE(review): the CPU copy is pinned to BF16 while the GPU copy
                // uses `dtype` — confirm this is intentional.
                let cpu = unsafe {
                    candle_nn::VarBuilder::from_mmaped_safetensors(
                        &xformer_paths
                            .iter()
                            .map(|p| p.as_path())
                            .collect::<Vec<_>>(),
                        DType::BF16,
                        &Device::Cpu,
                    )?
                };
                (gpu, cpu)
            };
            Ok(QwenImageTransformer::Offloaded(
                super::offload::OffloadedQwenImageTransformer::load(
                    gpu_vb,
                    cpu_vb,
                    cfg,
                    device,
                    self.base.gpu_ordinal,
                    &self.base.progress,
                )?,
            ))
        } else {
            // Resident path: LoRA takes precedence, then FP8, then plain safetensors.
            let xformer_vb = if has_lora {
                self.build_bf16_lora_var_builder(
                    &xformer_paths,
                    dtype,
                    device,
                    is_fp8,
                    active_loras,
                )?
            } else if is_fp8 {
                crate::weight_loader::load_fp8_safetensors(
                    &xformer_paths,
                    device,
                    "Qwen-Image transformer",
                    &self.base.progress,
                )?
            } else {
                crate::weight_loader::load_safetensors_with_progress(
                    &xformer_paths,
                    dtype,
                    device,
                    "Qwen-Image transformer",
                    &self.base.progress,
                )?
            };
            Ok(QwenImageTransformer::BF16(
                QwenImageTransformer2DModel::new(cfg, xformer_vb)?,
            ))
        }
    }
}
1339
1340 fn build_bf16_lora_var_builder<'a>(
1345 &self,
1346 xformer_paths: &[std::path::PathBuf],
1347 dtype: DType,
1348 device: &Device,
1349 is_fp8: bool,
1350 loras: &[mold_core::LoraWeight],
1351 ) -> Result<candle_nn::VarBuilder<'a>> {
1352 let adapters = super::lora::load_lora_adapters(loras, &self.base.progress)?;
1353 let specs: Vec<super::lora::QwenImageLoraSpec<'_>> = adapters
1354 .iter()
1355 .zip(loras.iter())
1356 .map(|(adapter, w)| super::lora::QwenImageLoraSpec {
1357 adapter: adapter.as_ref(),
1358 scale: w.scale,
1359 path_hash: super::lora::lora_path_hash(&w.path),
1360 })
1361 .collect();
1362
1363 let path_refs: Vec<&std::path::Path> = xformer_paths.iter().map(|p| p.as_path()).collect();
1364 let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::multi(&path_refs)? };
1365 let inner: Box<dyn candle_nn::var_builder::SimpleBackend> = if is_fp8 {
1366 self.base
1371 .progress
1372 .info("Detected FP8 safetensors — loading with LoRA-merging wrapper");
1373 Box::new(crate::weight_loader::NativeFp8Backend::from_mmap(tensors))
1374 } else {
1375 Box::new(tensors)
1378 };
1379
1380 let wrapped =
1381 super::lora::wrap_backend_with_lora(inner, &specs, &self.base.progress, None)?;
1382
1383 let target_dtype = if is_fp8 { DType::BF16 } else { dtype };
1384 Ok(candle_nn::VarBuilder::from_backend(
1385 wrapped,
1386 target_dtype,
1387 device.clone(),
1388 ))
1389 }
1390
1391 fn load_vae(&self, device: &Device, dtype: DType) -> Result<QwenImageVae> {
1393 let vb = self.load_vae_var_builder(device, dtype)?;
1394 Ok(QwenImageVae::from_var_builder(vb, device, dtype)?)
1395 }
1396
1397 fn load_vae_cpu_tensors(&self) -> Result<Option<Arc<HashMap<String, Tensor>>>> {
1398 let Some(shared_pool) = &self.shared_pool else {
1399 return Ok(None);
1400 };
1401 shared_pool
1402 .lock()
1403 .unwrap()
1404 .load_safetensors_cpu_tensors(std::slice::from_ref(&self.base.paths.vae))
1405 }
1406
1407 fn load_vae_var_builder<'a>(
1408 &self,
1409 device: &Device,
1410 dtype: DType,
1411 ) -> Result<candle_nn::VarBuilder<'a>> {
1412 if let Some(tensors) = self.load_vae_cpu_tensors()? {
1413 return Ok(encoders::park::varbuilder_from_parked(
1414 tensors.as_ref(),
1415 dtype,
1416 device,
1417 ));
1418 }
1419
1420 crate::weight_loader::load_safetensors_with_progress(
1421 std::slice::from_ref(&self.base.paths.vae),
1422 dtype,
1423 device,
1424 "Qwen-Image VAE",
1425 &self.base.progress,
1426 )
1427 }
1428
1429 fn resolve_text_encoder_source(
1434 &self,
1435 gpu_device: &Device,
1436 free_vram: u64,
1437 usage: Qwen2TextEncoderUsage,
1438 ) -> Result<ResolvedQwen2TextEncoder> {
1439 let preference = std::env::var("MOLD_QWEN2_VARIANT").ok();
1440 self.resolve_text_encoder_source_with_preference(
1441 gpu_device,
1442 free_vram,
1443 usage,
1444 preference.as_deref(),
1445 )
1446 }
1447
1448 fn resolve_text_encoder_source_with_preference(
1449 &self,
1450 gpu_device: &Device,
1451 free_vram: u64,
1452 usage: Qwen2TextEncoderUsage,
1453 preference: Option<&str>,
1454 ) -> Result<ResolvedQwen2TextEncoder> {
1455 let is_cuda = gpu_device.is_cuda();
1456 let is_metal = gpu_device.is_metal();
1457 let bf16_size_bytes = self
1458 .base
1459 .paths
1460 .text_encoder_files
1461 .iter()
1462 .filter_map(|p| std::fs::metadata(p).ok())
1463 .map(|m| m.len())
1464 .sum();
1465 if self.is_edit_family() {
1466 let mut resolved = Self::choose_text_encoder_source(
1467 preference,
1468 is_cuda,
1469 is_metal,
1470 free_vram,
1471 bf16_size_bytes,
1472 Qwen2TextEncoderUsage::Resident,
1473 )?;
1474 resolved.vision_paths = self.base.paths.text_encoder_files.clone();
1475 if resolved.is_gguf {
1476 let variant = mold_core::manifest::find_qwen2_vl_variant(&resolved.variant_label)
1477 .ok_or_else(|| {
1478 anyhow::anyhow!("unknown Qwen2.5-VL variant '{}'", resolved.variant_label)
1479 })?;
1480 resolved.paths = vec![
1481 crate::encoders::variant_resolution::resolve_qwen2_vl_gguf_path(
1482 &self.base.progress,
1483 variant,
1484 )?,
1485 ];
1486 } else {
1487 resolved.paths = self.base.paths.text_encoder_files.clone();
1488 }
1489 return Ok(resolved);
1490 }
1491 let mut resolved = Self::choose_text_encoder_source(
1492 preference,
1493 is_cuda,
1494 is_metal,
1495 free_vram,
1496 bf16_size_bytes,
1497 usage,
1498 )?;
1499
1500 if resolved.is_gguf {
1501 let variant = mold_core::manifest::find_qwen2_vl_variant(&resolved.variant_label)
1502 .ok_or_else(|| {
1503 anyhow::anyhow!("unknown Qwen2.5-VL variant '{}'", resolved.variant_label)
1504 })?;
1505 resolved.paths = vec![
1506 crate::encoders::variant_resolution::resolve_qwen2_vl_gguf_path(
1507 &self.base.progress,
1508 variant,
1509 )?,
1510 ];
1511 } else {
1512 resolved.paths = self.base.paths.text_encoder_files.clone();
1513 }
1514 resolved.vision_paths = vec![];
1515
1516 match preference {
1517 Some(tag) if tag != "auto" && tag != "bf16" => self.base.progress.info(&format!(
1518 "Using quantized Qwen2.5-VL {} ({}) on {} (explicit)",
1519 resolved.variant_label,
1520 fmt_gb(resolved.size_bytes),
1521 if resolved.auto_use_gpu { "GPU" } else { "CPU" },
1522 )),
1523 Some("bf16") => {}
1524 _ if is_metal && resolved.is_gguf && resolved.variant_label == "q6" => self
1525 .base
1526 .progress
1527 .info(&format!(
1528 "Metal auto mode selected quantized Qwen2.5-VL {} ({}) for lower memory pressure",
1529 resolved.variant_label,
1530 fmt_gb(resolved.size_bytes),
1531 )),
1532 _ if is_metal && resolved.is_gguf => self.base.progress.info(&format!(
1533 "Metal auto mode forcing quantized Qwen2.5-VL {} ({}) to avoid BF16 memory pressure",
1534 resolved.variant_label,
1535 fmt_gb(resolved.size_bytes),
1536 )),
1537 _ if is_cuda && resolved.is_gguf && resolved.auto_use_gpu => self.base.progress.info(
1538 &format!(
1539 "CUDA auto mode selected quantized Qwen2.5-VL {} ({}) on GPU",
1540 resolved.variant_label,
1541 fmt_gb(resolved.size_bytes),
1542 ),
1543 ),
1544 _ if is_cuda && resolved.is_gguf => self.base.progress.info(&format!(
1545 "CUDA auto mode selected quantized Qwen2.5-VL {} ({}) on CPU to avoid large BF16 host residency",
1546 resolved.variant_label,
1547 fmt_gb(resolved.size_bytes),
1548 )),
1549 _ => {}
1550 }
1551
1552 Ok(resolved)
1553 }
1554
1555 fn can_keep_transformer_hot_for_vae(loaded: &LoadedQwenImage) -> bool {
1556 Self::qwen_transformer_can_stay_hot_for_vae(
1557 loaded.device.is_cuda(),
1558 loaded.vae_device.is_cuda(),
1559 matches!(
1560 loaded.transformer.as_ref(),
1561 Some(QwenImageTransformer::Quantized(_))
1562 ),
1563 )
1564 }
1565
/// Pure decision rule: the transformer may stay resident only when it is on
/// CUDA, the VAE is on CUDA, and the transformer is quantized.
fn qwen_transformer_can_stay_hot_for_vae(
    transformer_is_cuda: bool,
    vae_is_cuda: bool,
    transformer_is_quantized: bool,
) -> bool {
    matches!(
        (transformer_is_cuda, vae_is_cuda, transformer_is_quantized),
        (true, true, true)
    )
}
1573
1574 fn decode_vae_gpu_only(
1575 latents: &Tensor,
1576 vae: &QwenImageVae,
1577 vae_device: &Device,
1578 sync_device: &Device,
1579 progress: &ProgressReporter,
1580 prefer_tiled: bool,
1581 ) -> Result<Tensor> {
1582 if vae_device.is_cuda() && prefer_tiled {
1583 progress.info("Selecting tiled GPU VAE decode proactively");
1584 return Self::decode_vae_tiled(latents, vae, vae_device, progress);
1585 }
1586
1587 let decode_latents = latents.to_device(vae_device)?.to_dtype(DType::F32)?;
1588 match vae.decode(&decode_latents) {
1589 Ok(image) => Ok(image),
1590 Err(e) if vae_device.is_cuda() && Self::is_oom_error(&e) => {
1591 progress.info(
1592 "Resident-transformer VAE decode OOM on GPU — retrying with tiled GPU decode before dropping transformer",
1593 );
1594 sync_device.synchronize()?;
1595 Self::decode_vae_tiled(latents, vae, vae_device, progress)
1596 }
1597 Err(e) => Err(e.into()),
1598 }
1599 }
1600
1601 fn load_text_encoder(
1602 &self,
1603 resolved: &ResolvedQwen2TextEncoder,
1604 tokenizer_path: &std::path::PathBuf,
1605 tokenizer: Arc<Tokenizer>,
1606 device: &Device,
1607 dtype: DType,
1608 preload_weights: bool,
1609 ) -> Result<encoders::qwen2_text::Qwen2TextEncoder> {
1610 if resolved.is_gguf {
1611 if preload_weights {
1612 encoders::qwen2_text::Qwen2TextEncoder::load_gguf_with_tokenizer(
1613 &resolved.paths[0],
1614 tokenizer_path,
1615 Some(tokenizer),
1616 device,
1617 dtype,
1618 &resolved.vision_paths,
1619 &self.base.progress,
1620 )
1621 } else {
1622 encoders::qwen2_text::Qwen2TextEncoder::prepare_gguf_with_tokenizer(
1623 &resolved.paths[0],
1624 tokenizer_path,
1625 Some(tokenizer),
1626 device,
1627 dtype,
1628 &resolved.vision_paths,
1629 )
1630 }
1631 } else {
1632 let is_fp8 = text_encoder_is_fp8(&resolved.paths);
1633 if is_fp8 {
1634 self.base
1635 .progress
1636 .info("Detected FP8 text encoder — loading as BF16 on GPU");
1637 }
1638 if preload_weights {
1639 encoders::qwen2_text::Qwen2TextEncoder::load_bf16_with_tokenizer(
1640 &resolved.paths,
1641 tokenizer_path,
1642 Some(tokenizer),
1643 device,
1644 dtype,
1645 self.is_edit_family(),
1646 &self.base.progress,
1647 )
1648 } else {
1649 encoders::qwen2_text::Qwen2TextEncoder::prepare_bf16_with_tokenizer(
1650 &resolved.paths,
1651 tokenizer_path,
1652 Some(tokenizer),
1653 device,
1654 dtype,
1655 self.is_edit_family(),
1656 )
1657 }
1658 }
1659 }
1660
1661 fn resolve_text_encoder_plan(
1663 &self,
1664 gpu_device: &Device,
1665 resolved: &ResolvedQwen2TextEncoder,
1666 free_vram: u64,
1667 ) -> (Qwen2TextEncoderPlan, String) {
1668 let is_cuda = gpu_device.is_cuda();
1669 let is_metal = gpu_device.is_metal();
1670 let plan = Self::qwen2_text_encoder_plan_for_mode(
1671 Qwen2TextEncoderMode::from_env(),
1672 is_cuda,
1673 is_metal,
1674 resolved,
1675 );
1676 let label = if plan.use_gpu { "GPU" } else { "CPU" };
1677 if plan.use_cpu_staging {
1678 self.base
1679 .progress
1680 .info("Qwen2.5 text encoder on GPU with CPU staging after encoding");
1681 } else if !plan.use_gpu {
1682 if resolved.is_gguf {
1683 self.base.progress.info(&format!(
1684 "Qwen2.5 text encoder on CPU ({} variant {}, {} free)",
1685 resolved.variant_label,
1686 fmt_gb(resolved.size_bytes),
1687 fmt_gb(free_vram),
1688 ));
1689 } else if is_metal || is_cuda {
1690 self.base.progress.info(&format!(
1691 "Qwen2.5 text encoder on CPU ({} free < {} threshold)",
1692 fmt_gb(free_vram),
1693 fmt_gb(QWEN2_FP16_VRAM_THRESHOLD),
1694 ));
1695 }
1696 }
1697 (plan, label.to_string())
1698 }
1699
1700 fn qwen2_text_encoder_plan_for_mode(
1701 mode: Qwen2TextEncoderMode,
1702 is_cuda: bool,
1703 is_metal: bool,
1704 resolved: &ResolvedQwen2TextEncoder,
1705 ) -> Qwen2TextEncoderPlan {
1706 match mode {
1707 Qwen2TextEncoderMode::Gpu => Qwen2TextEncoderPlan {
1708 use_gpu: is_cuda || is_metal,
1709 use_cpu_staging: false,
1710 },
1711 Qwen2TextEncoderMode::CpuStage => Qwen2TextEncoderPlan {
1712 use_gpu: is_cuda || is_metal,
1713 use_cpu_staging: is_cuda || is_metal,
1714 },
1715 Qwen2TextEncoderMode::Cpu => Qwen2TextEncoderPlan {
1716 use_gpu: false,
1717 use_cpu_staging: false,
1718 },
1719 Qwen2TextEncoderMode::Auto => Qwen2TextEncoderPlan {
1720 use_gpu: resolved.auto_use_gpu,
1721 use_cpu_staging: is_metal && resolved.auto_use_gpu && !resolved.is_gguf,
1722 },
1723 }
1724 }
1725
/// Eagerly loads all Qwen-Image components (transformer, VAE, text encoder)
/// and stores them in `self.base.loaded`.
///
/// Does nothing when components are already loaded or when the load strategy
/// is `LoadStrategy::Sequential`.
///
/// The transformer is loaded first; free VRAM is measured afterwards and
/// feeds the VAE and text-encoder placement decisions, each of which can be
/// overridden through `self.pending_placement`.
///
/// # Errors
/// Propagates path validation, device resolution and component loading
/// failures.
pub fn load(&mut self) -> Result<()> {
    // Already loaded: nothing to do.
    if self.base.loaded.is_some() {
        return Ok(());
    }

    // Sequential strategy: nothing is loaded here.
    // NOTE(review): presumably components are loaded per request in
    // `generate_sequential` instead — confirm against the dispatching caller.
    if self.base.load_strategy == LoadStrategy::Sequential {
        return Ok(());
    }

    tracing::info!(model = %self.base.model_name, "loading Qwen-Image model components...");

    let text_tokenizer_path = self.validate_paths()?;
    // Transformer device: taken from the placement reference, with
    // `create_device` as the fallback passed to `resolve_device`.
    let transformer_ref = effective_device_ref(
        self.pending_placement.as_ref(),
        |adv| Some(adv.transformer),
        false,
    );
    let device = crate::device::resolve_device(Some(transformer_ref), || {
        crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)
    })?;
    let transformer_cfg = self.transformer_config();
    let transformer_is_quantized = self.detect_is_quantized();
    let dtype = crate::engine::gpu_dtype(&device);

    // `xformer_paths` is only used for the shard count in the stage label.
    let xformer_paths = self.transformer_paths();
    let xformer_label = if transformer_is_quantized {
        "Loading Qwen-Image transformer (quantized)".to_string()
    } else {
        format!(
            "Loading Qwen-Image transformer ({} shards)",
            xformer_paths.len()
        )
    };
    self.base.progress.stage_start(&xformer_label);
    let xformer_start = Instant::now();
    // No request is available here, so memory estimates use the native resolution.
    let transformer = self.load_transformer(
        &device,
        dtype,
        &transformer_cfg,
        QWEN_NATIVE_WIDTH,
        QWEN_NATIVE_HEIGHT,
    )?;
    self.base
        .progress
        .stage_done(&xformer_label, xformer_start.elapsed());
    tracing::info!("Qwen-Image transformer loaded");

    // Two readings taken after the transformer is resident: `free_raw` is only
    // reported to the user, `free` drives the placement decisions below.
    let free_raw = free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
    let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
    let is_cuda = device.is_cuda();
    let is_metal = device.is_metal();
    if free_raw > 0 {
        self.base.progress.info(&format!(
            "Free VRAM after transformer: {}",
            fmt_gb(free_raw)
        ));
    }

    // Automatic VAE placement feeds the fallback closure; the final answer is
    // re-derived from the resolved device, which may differ from the automatic
    // choice.
    let vae_on_gpu = should_use_gpu(is_cuda, is_metal, free, VAE_DECODE_VRAM_THRESHOLD);
    let vae_ref =
        effective_device_ref(self.pending_placement.as_ref(), |adv| Some(adv.vae), false);
    let vae_device = crate::device::resolve_device(Some(vae_ref), || {
        Ok(if vae_on_gpu {
            device.clone()
        } else {
            Device::Cpu
        })
    })?;
    let vae_on_gpu = !vae_device.is_cpu();
    // The VAE is always loaded in F32, independent of the transformer dtype.
    let vae_dtype = DType::F32;
    let vae_device_label = if vae_on_gpu { "GPU" } else { "CPU" };

    let vae_label = format!("Loading Qwen-Image VAE ({}, F32)", vae_device_label);
    self.base.progress.stage_start(&vae_label);
    let vae_start = Instant::now();
    let vae = self.load_vae(&vae_device, vae_dtype)?;
    self.base
        .progress
        .stage_done(&vae_label, vae_start.elapsed());

    // Text encoder: pick the weight source, then the automatic device plan.
    let resolved_text_encoder =
        self.resolve_text_encoder_source(&device, free, Qwen2TextEncoderUsage::Resident)?;
    let (te_plan, te_auto_device_label) =
        self.resolve_text_encoder_plan(&device, &resolved_text_encoder, free);
    let qwen_ref = effective_device_ref(self.pending_placement.as_ref(), |adv| adv.qwen, true);
    let auto_te_device = if te_plan.use_gpu {
        device.clone()
    } else {
        Device::Cpu
    };
    let te_device =
        crate::device::resolve_device(Some(qwen_ref), || Ok(auto_te_device.clone()))?;
    let te_use_gpu = !te_device.is_cpu();
    // Keep the plan's label unless the resolved device disagrees with the plan.
    let te_device_label: String = if te_use_gpu == te_plan.use_gpu {
        te_auto_device_label
    } else if te_use_gpu {
        "GPU".into()
    } else {
        "CPU".into()
    };
    let te_dtype = Self::text_encoder_load_dtype(te_use_gpu, dtype);

    // "Loading" vs. "Preparing" mirrors whether weights are preloaded now.
    let preload_text_encoder = self.should_preload_text_encoder();
    let te_label = if resolved_text_encoder.is_gguf {
        if preload_text_encoder {
            format!(
                "Loading Qwen2.5 text encoder ({} GGUF, {})",
                resolved_text_encoder.variant_label, te_device_label
            )
        } else {
            format!(
                "Preparing Qwen2.5 text encoder ({} GGUF, {})",
                resolved_text_encoder.variant_label, te_device_label
            )
        }
    } else if preload_text_encoder {
        format!(
            "Loading Qwen2.5 text encoder ({} shards, {})",
            resolved_text_encoder.paths.len(),
            te_device_label,
        )
    } else {
        format!(
            "Preparing Qwen2.5 text encoder ({} shards, {})",
            resolved_text_encoder.paths.len(),
            te_device_label,
        )
    };
    self.base.progress.stage_start(&te_label);
    let te_start = Instant::now();
    let text_tokenizer = self.load_text_tokenizer(&text_tokenizer_path)?;
    let text_encoder = self.load_text_encoder(
        &resolved_text_encoder,
        &text_tokenizer_path,
        text_tokenizer,
        &te_device,
        te_dtype,
        preload_text_encoder,
    )?;
    self.base.progress.stage_done(&te_label, te_start.elapsed());
    if preload_text_encoder {
        tracing::info!(device = %te_device_label, "Qwen2.5 text encoder loaded");
    } else {
        tracing::info!(device = %te_device_label, "Qwen2.5 text encoder prepared for staged loading");
    }

    // Publish everything at once so a failure above leaves `loaded` untouched.
    self.base.loaded = Some(LoadedQwenImage {
        transformer: Some(transformer),
        text_encoder,
        vae,
        vae_path: self.base.paths.vae.clone(),
        transformer_cfg,
        device,
        vae_device,
        dtype,
    });

    tracing::info!(model = %self.base.model_name, "all Qwen-Image components loaded");
    Ok(())
}
1902
1903 fn reload_transformer(
1905 &self,
1906 loaded: &mut LoadedQwenImage,
1907 width: usize,
1908 height: usize,
1909 ) -> Result<()> {
1910 let transformer = self.load_transformer(
1911 &loaded.device,
1912 loaded.dtype,
1913 &loaded.transformer_cfg,
1914 width,
1915 height,
1916 )?;
1917 loaded.transformer = Some(transformer);
1918 Ok(())
1919 }
1920
1921 fn generate_sequential(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1923 let text_tokenizer_path = self.validate_paths()?;
1924 let transformer_cfg = self.transformer_config();
1925
1926 let transformer_ref = effective_device_ref(
1927 self.pending_placement.as_ref(),
1928 |adv| Some(adv.transformer),
1929 false,
1930 );
1931 let device = crate::device::resolve_device(Some(transformer_ref), || {
1932 crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)
1933 })?;
1934 let dtype = crate::engine::gpu_dtype(&device);
1935 let transformer_is_quantized = self.detect_is_quantized();
1936
1937 let start = Instant::now();
1938 let seed = req.seed.unwrap_or_else(rand_seed);
1939
1940 let width = req.width as usize;
1941 let height = req.height as usize;
1942 let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
1945 let resolved_text_encoder =
1946 self.resolve_text_encoder_source(&device, free, Qwen2TextEncoderUsage::Sequential)?;
1947 let (plan, _device_label) =
1948 self.resolve_text_encoder_plan(&device, &resolved_text_encoder, free);
1949 let use_cpu_staging = plan.use_cpu_staging;
1950
1951 tracing::info!(
1952 prompt = %req.prompt,
1953 seed, width, height,
1954 steps = req.steps,
1955 "starting sequential Qwen-Image generation"
1956 );
1957
1958 self.base
1959 .progress
1960 .info("Using sequential loading (load-use-drop) to minimize peak memory");
1961
1962 let use_cfg = req.guidance > 1.0;
1964 let prompt_key = prompt_text_key(&req.prompt);
1965 let uncond_key = prompt_text_key(QWEN_EMPTY_NEGATIVE_PROMPT);
1966 let (prompt_cached, uncond_cached) = {
1967 let mut cache = self.prompt_cache.lock().expect("cache poisoned");
1968 let prompt_cached = cache.get_cloned(&prompt_key);
1969 let uncond_cached = if use_cfg {
1970 cache.get_cloned(&uncond_key)
1971 } else {
1972 None
1973 };
1974 (prompt_cached, uncond_cached)
1975 };
1976 let both_cached = prompt_cached.is_some() && (!use_cfg || uncond_cached.is_some());
1977
1978 let (mut encoder_hidden_states, mut encoder_attention_mask, mut uncond_hs, mut uncond_mask) =
1979 if both_cached {
1980 self.base.progress.cache_hit("prompt conditioning");
1981 let cached = prompt_cached.unwrap();
1982 let restore_device = if use_cpu_staging {
1983 &Device::Cpu
1984 } else {
1985 &device
1986 };
1987 let restore_dtype = if use_cpu_staging { DType::F32 } else { dtype };
1988 let (hs, mask) = cached.restore(restore_device, restore_dtype)?;
1989 let (u_hs, u_mask) = if use_cfg {
1990 let ucached = uncond_cached.unwrap();
1991 let (u_hs, u_mask) = ucached.restore(restore_device, restore_dtype)?;
1992 (Some(u_hs), Some(u_mask))
1993 } else {
1994 (None, None)
1995 };
1996 (hs, mask, u_hs, u_mask)
1997 } else {
1998 let (te_plan, te_auto_device_label) =
1999 self.resolve_text_encoder_plan(&device, &resolved_text_encoder, free);
2000 let qwen_ref =
2001 effective_device_ref(self.pending_placement.as_ref(), |adv| adv.qwen, true);
2002 let auto_te_device = if te_plan.use_gpu {
2003 device.clone()
2004 } else {
2005 Device::Cpu
2006 };
2007 let te_device =
2008 crate::device::resolve_device(Some(qwen_ref), || Ok(auto_te_device.clone()))?;
2009 let te_use_gpu = !te_device.is_cpu();
2010 let te_device_label: String = if te_use_gpu == te_plan.use_gpu {
2011 te_auto_device_label
2012 } else if te_use_gpu {
2013 "GPU".into()
2014 } else {
2015 "CPU".into()
2016 };
2017 let te_dtype = Self::text_encoder_load_dtype(te_use_gpu, dtype);
2018
2019 let te_label = if resolved_text_encoder.is_gguf {
2020 format!(
2021 "Loading Qwen2.5 text encoder ({} GGUF, {})",
2022 resolved_text_encoder.variant_label, te_device_label
2023 )
2024 } else {
2025 format!(
2026 "Loading Qwen2.5 text encoder ({} shards, {})",
2027 resolved_text_encoder.paths.len(),
2028 te_device_label,
2029 )
2030 };
2031 if te_plan.use_cpu_staging && device.is_metal() && !resolved_text_encoder.is_gguf {
2032 self.base.progress.info(
2033 "Skipping hard preflight for Qwen2.5 text encoder on Metal; sequential mode spills prompt conditioning to CPU after encoding",
2034 );
2035 } else {
2036 let te_activation_budget = crate::device::activation_bytes(
2037 req.width,
2038 req.height,
2039 1,
2040 crate::device::dtype_bytes(te_dtype),
2041 crate::device::ActivationFamily::SmallTransformer,
2042 );
2043 preflight_memory_check(
2044 "Qwen2.5 text encoder",
2045 resolved_text_encoder.size_bytes,
2046 te_activation_budget,
2047 )?;
2048 }
2049
2050 if let Some(status) = memory_status_string() {
2051 self.base.progress.info(&status);
2052 }
2053
2054 self.base.progress.stage_start(&te_label);
2055 let te_start = Instant::now();
2056 let text_tokenizer = self.load_text_tokenizer(&text_tokenizer_path)?;
2057 let mut text_encoder = self.load_text_encoder(
2058 &resolved_text_encoder,
2059 &text_tokenizer_path,
2060 text_tokenizer,
2061 &te_device,
2062 te_dtype,
2063 true,
2064 )?;
2065 self.base.progress.stage_done(&te_label, te_start.elapsed());
2066
2067 let (hs, mask) = Self::encode_prompt_cached(
2068 &self.base.progress,
2069 &self.prompt_cache,
2070 &mut text_encoder,
2071 &req.prompt,
2072 &device,
2073 dtype,
2074 )?;
2075 let (hs, mask) = Self::maybe_spill_conditioning(use_cpu_staging, hs, mask)?;
2076
2077 let (u_hs, u_mask) = if use_cfg {
2078 let (hs, mask) = Self::encode_prompt_cached(
2079 &self.base.progress,
2080 &self.prompt_cache,
2081 &mut text_encoder,
2082 QWEN_EMPTY_NEGATIVE_PROMPT,
2083 &device,
2084 dtype,
2085 )?;
2086 let (hs, mask) = Self::maybe_spill_conditioning(use_cpu_staging, hs, mask)?;
2087 (Some(hs), Some(mask))
2088 } else {
2089 (None, None)
2090 };
2091
2092 drop(text_encoder);
2093 device.synchronize()?;
2095 if let Some(status) = crate::device::memory_status_string() {
2096 if use_cpu_staging {
2097 self.base.progress.info(&format!(
2098 "Freed Qwen2.5 text encoder and spilled prompt conditioning to CPU — {status}"
2099 ));
2100 } else {
2101 self.base
2102 .progress
2103 .info(&format!("Freed Qwen2.5 text encoder — {status}"));
2104 }
2105 } else {
2106 if use_cpu_staging {
2107 self.base.progress.info(
2108 "Freed Qwen2.5 text encoder and spilled prompt conditioning to CPU",
2109 );
2110 } else {
2111 self.base.progress.info("Freed Qwen2.5 text encoder");
2112 }
2113 }
2114
2115 (hs, mask, u_hs, u_mask)
2116 };
2117
2118 if use_cfg {
2119 let ((cond_hs, cond_mask), (neg_hs, neg_mask)) = align_cfg_conditioning(
2120 &encoder_hidden_states,
2121 &encoder_attention_mask,
2122 uncond_hs.as_ref().expect("unconditional prompt missing"),
2123 uncond_mask.as_ref().expect("unconditional mask missing"),
2124 )?;
2125 encoder_hidden_states = cond_hs;
2126 encoder_attention_mask = cond_mask;
2127 uncond_hs = Some(neg_hs);
2128 uncond_mask = Some(neg_mask);
2129 }
2130
2131 let xformer_paths = self.transformer_paths();
2133 let xformer_size: u64 = xformer_paths
2134 .iter()
2135 .filter_map(|p| std::fs::metadata(p).ok())
2136 .map(|m| m.len())
2137 .sum();
2138 let xformer_activation_budget = crate::device::activation_bytes(
2139 req.width,
2140 req.height,
2141 if req.guidance > 1.0 { 2 } else { 1 },
2142 crate::device::dtype_bytes(dtype),
2143 crate::device::ActivationFamily::QwenImageDit,
2144 );
2145 preflight_memory_check(
2146 "Qwen-Image transformer",
2147 xformer_size,
2148 xformer_activation_budget,
2149 )?;
2150
2151 if let Some(status) = memory_status_string() {
2152 self.base.progress.info(&status);
2153 }
2154
2155 let xformer_label = if transformer_is_quantized {
2156 "Loading Qwen-Image transformer (quantized)".to_string()
2157 } else {
2158 format!(
2159 "Loading Qwen-Image transformer ({} shards)",
2160 xformer_paths.len()
2161 )
2162 };
2163 self.base.progress.stage_start(&xformer_label);
2164 let xformer_start = Instant::now();
2165 let transformer = self.load_transformer(&device, dtype, &transformer_cfg, width, height)?;
2166 self.base
2167 .progress
2168 .stage_done(&xformer_label, xformer_start.elapsed());
2169
2170 if use_cpu_staging {
2171 encoder_hidden_states = encoder_hidden_states.to_device(&device)?.to_dtype(dtype)?;
2172 encoder_attention_mask = encoder_attention_mask.to_device(&device)?;
2173 if let Some(hs) = uncond_hs.take() {
2174 uncond_hs = Some(hs.to_device(&device)?.to_dtype(dtype)?);
2175 }
2176 if let Some(mask) = uncond_mask.take() {
2177 uncond_mask = Some(mask.to_device(&device)?);
2178 }
2179 if let Some(status) = memory_status_string() {
2180 self.base.progress.info(&format!(
2181 "Restored prompt conditioning to GPU for denoising — {status}"
2182 ));
2183 } else {
2184 self.base
2185 .progress
2186 .info("Restored prompt conditioning to GPU for denoising");
2187 }
2188 }
2189
2190 let vae_downsample = 8;
2192 let latent_h = height / vae_downsample;
2193 let latent_w = width / vae_downsample;
2194 let is_img2img = req.source_image.is_some();
2195
2196 let (prepared_img2img_latents, inpaint_ctx) = if let Some(ref source_bytes) =
2198 req.source_image
2199 {
2200 let free_for_encode = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
2202 let encode_on_gpu = should_use_gpu(
2203 device.is_cuda(),
2204 device.is_metal(),
2205 free_for_encode,
2206 VAE_DECODE_VRAM_THRESHOLD,
2207 );
2208 let encode_device = if encode_on_gpu {
2209 device.clone()
2210 } else {
2211 Device::Cpu
2212 };
2213 let encode_label = if encode_on_gpu { "GPU" } else { "CPU" };
2214
2215 let vae_label = format!("Loading Qwen-Image VAE ({}, F32) for encode", encode_label);
2216 self.base.progress.stage_start(&vae_label);
2217 let vae_start = Instant::now();
2218 let encode_vae = self.load_vae(&encode_device, DType::F32)?;
2219 self.base
2220 .progress
2221 .stage_done(&vae_label, vae_start.elapsed());
2222
2223 let encoded = Self::encode_vae_with_fallback(
2224 source_bytes,
2225 req.width,
2226 req.height,
2227 &encode_vae,
2228 &encode_device,
2229 &device,
2230 &self.base.progress,
2231 || self.load_vae(&Device::Cpu, DType::F32),
2232 )?;
2233 let encoded = encoded.to_device(&device)?.to_dtype(dtype)?;
2234 let start_sigma = QwenImageScheduler::new_img2img(
2235 req.steps as usize,
2236 image_seq_len(latent_h, latent_w, transformer_cfg.patch_size),
2237 req.strength,
2238 )
2239 .0
2240 .initial_sigma();
2241 let prepared = crate::img2img::prepare_flow_match_img2img(
2242 &encoded,
2243 seed,
2244 &[1, 16, latent_h, latent_w],
2245 start_sigma,
2246 req.mask_image.as_deref(),
2247 latent_h,
2248 latent_w,
2249 &device,
2250 dtype,
2251 )?;
2252
2253 drop(encode_vae);
2255 device.synchronize()?;
2256
2257 tracing::info!(
2258 strength = req.strength,
2259 "img2img: encoded source image to latents"
2260 );
2261
2262 (Some(prepared.initial_latents), prepared.inpaint_ctx)
2263 } else {
2264 (None, None)
2265 };
2266
2267 let image_seq_len = image_seq_len(latent_h, latent_w, transformer_cfg.patch_size);
2268 let (mut scheduler, num_steps) = if is_img2img {
2269 QwenImageScheduler::new_img2img(req.steps as usize, image_seq_len, req.strength)
2270 } else {
2271 let sched = QwenImageScheduler::new(req.steps as usize, image_seq_len);
2272 let n = sched.num_steps();
2273 (sched, n)
2274 };
2275
2276 let mut latents = if let Some(initial) = &prepared_img2img_latents {
2278 initial.clone()
2279 } else {
2280 let noise =
2281 crate::engine::seeded_randn(seed, &[1, 16, latent_h, latent_w], &device, dtype)?;
2282 (noise * scheduler.initial_sigma())?
2283 };
2284
2285 let denoise_label = format!("Denoising ({} steps)", num_steps);
2286 self.base.progress.stage_start(&denoise_label);
2287 let denoise_start = Instant::now();
2288
2289 if std::env::var_os("MOLD_QWEN_DEBUG").is_some() {
2290 eprintln!(
2291 "[qwen-debug] cfg={} guidance={:.1} image_seq_len={} sigmas[0]={:.4} sigmas[last]={:.4} img2img={}",
2292 use_cfg,
2293 req.guidance,
2294 image_seq_len,
2295 scheduler.sigmas[0],
2296 scheduler.sigmas[scheduler.sigmas.len() - 1],
2297 is_img2img,
2298 );
2299 }
2300
2301 let use_batched_cfg = use_cfg && transformer.supports_cfg_batching();
2302 if use_cfg && !use_batched_cfg {
2303 self.base.progress.info(
2304 "Low-memory quantized Qwen CUDA path detected — disabling CFG batching to reduce peak CUDA memory",
2305 );
2306 }
2307
2308 let (batched_hs, batched_mask) = if use_batched_cfg {
2311 let hs = Tensor::cat(&[&encoder_hidden_states, uncond_hs.as_ref().unwrap()], 0)?;
2312 let mask = Tensor::cat(&[&encoder_attention_mask, uncond_mask.as_ref().unwrap()], 0)?;
2313 (hs, mask)
2314 } else {
2315 (
2316 encoder_hidden_states.clone(),
2317 encoder_attention_mask.clone(),
2318 )
2319 };
2320
2321 for step in 0..num_steps {
2322 let step_start = Instant::now();
2323 let t = scheduler.current_timestep();
2324 let noise_pred = if use_cfg {
2325 let (cond_pred, uncond_pred) = if use_batched_cfg {
2326 let t_tensor =
2327 Tensor::from_vec(vec![t as f32; 2], (2,), &device)?.to_dtype(dtype)?;
2328 let batched_latents = Tensor::cat(&[&latents, &latents], 0)?;
2329 let batched_pred = transformer.forward(
2330 &batched_latents,
2331 &t_tensor,
2332 &batched_hs,
2333 &batched_mask,
2334 )?;
2335 (batched_pred.narrow(0, 0, 1)?, batched_pred.narrow(0, 1, 1)?)
2336 } else {
2337 let t_tensor =
2338 Tensor::from_vec(vec![t as f32], (1,), &device)?.to_dtype(dtype)?;
2339 (
2340 transformer.forward(
2341 &latents,
2342 &t_tensor,
2343 &encoder_hidden_states,
2344 &encoder_attention_mask,
2345 )?,
2346 transformer.forward(
2347 &latents,
2348 &t_tensor,
2349 uncond_hs.as_ref().unwrap(),
2350 uncond_mask.as_ref().unwrap(),
2351 )?,
2352 )
2353 };
2354 if step == 0 {
2355 Self::debug_tensor_stats("cond_pred[0]", &cond_pred);
2356 Self::debug_tensor_stats("uncond_pred[0]", &uncond_pred);
2357 }
2358 let cond_f32 = cond_pred.to_dtype(DType::F32)?;
2361 let uncond_f32 = uncond_pred.to_dtype(DType::F32)?;
2362 let comb = (&uncond_f32 + ((&cond_f32 - &uncond_f32)? * req.guidance)?)?;
2363 let cond_norm = cond_f32.sqr()?.sum_keepdim(1)?.sqrt()?;
2364 let comb_norm = comb.sqr()?.sum_keepdim(1)?.sqrt()?.clamp(1e-8, f64::MAX)?;
2365 let rescaled = comb.broadcast_mul(&(cond_norm / comb_norm)?)?;
2366 rescaled.to_dtype(dtype)?
2367 } else {
2368 let t_tensor = Tensor::from_vec(vec![t as f32], (1,), &device)?.to_dtype(dtype)?;
2369 transformer.forward(
2370 &latents,
2371 &t_tensor,
2372 &encoder_hidden_states,
2373 &encoder_attention_mask,
2374 )?
2375 };
2376 if step == 0 || step == num_steps / 2 || step == num_steps - 1 {
2377 Self::debug_tensor_stats(&format!("noise_pred[{step}]"), &noise_pred);
2378 Self::debug_tensor_stats(&format!("latents[{step}]"), &latents);
2379 }
2380 if step == 0 {
2381 Self::validate_qwen_tensor_boundary("noise_pred[0]", &noise_pred)?;
2382 }
2383 latents = scheduler.step(&noise_pred, &latents)?;
2384 if step == num_steps - 1 {
2385 Self::validate_qwen_tensor_boundary("latents_final", &latents)?;
2386 }
2387
2388 if let Some(ref ctx) = inpaint_ctx {
2390 latents = crate::img2img::apply_flow_match_inpaint(
2391 &latents,
2392 ctx,
2393 scheduler.sigmas[step + 1],
2394 )?;
2395 }
2396
2397 if std::env::var_os("MOLD_QWEN_DEBUG").is_some() {
2398 let n = latents
2399 .ne(&latents)?
2400 .to_dtype(candle_core::DType::U32)?
2401 .sum_all()?
2402 .to_scalar::<u32>()?;
2403 if n > 0 {
2404 eprintln!(
2405 "[qwen-nan] NaN in latents AFTER step {step}: {n}/{}",
2406 latents.elem_count()
2407 );
2408 }
2409 }
2410 self.base.progress.emit(ProgressEvent::DenoiseStep {
2411 step: step + 1,
2412 total: num_steps,
2413 elapsed: step_start.elapsed(),
2414 });
2415 }
2416
2417 self.base
2418 .progress
2419 .stage_done(&denoise_label, denoise_start.elapsed());
2420
2421 drop(transformer);
2423 drop(encoder_hidden_states);
2424 drop(encoder_attention_mask);
2425 drop(uncond_hs);
2426 drop(uncond_mask);
2427 device.synchronize()?;
2428 self.base.progress.info("Freed Qwen-Image transformer");
2429
2430 if let Some(status) = memory_status_string() {
2432 self.base.progress.info(&status);
2433 }
2434
2435 let free_for_vae = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
2437 let vae_on_gpu = should_use_gpu(
2438 device.is_cuda(),
2439 device.is_metal(),
2440 free_for_vae,
2441 VAE_DECODE_VRAM_THRESHOLD,
2442 );
2443 let vae_ref =
2444 effective_device_ref(self.pending_placement.as_ref(), |adv| Some(adv.vae), false);
2445 let vae_device = crate::device::resolve_device(Some(vae_ref), || {
2446 Ok(if vae_on_gpu {
2447 device.clone()
2448 } else {
2449 Device::Cpu
2450 })
2451 })?;
2452 let vae_on_gpu = !vae_device.is_cpu();
2453 let vae_dtype = DType::F32;
2456 let vae_device_label = if vae_on_gpu { "GPU" } else { "CPU" };
2457
2458 let vae_label = format!("Loading Qwen-Image VAE ({}, F32)", vae_device_label);
2459 self.base.progress.stage_start(&vae_label);
2460 let vae_start = Instant::now();
2461 let vae = self.load_vae(&vae_device, vae_dtype)?;
2462 self.base
2463 .progress
2464 .stage_done(&vae_label, vae_start.elapsed());
2465
2466 self.base.progress.stage_start("VAE decode");
2467 let vae_decode_start = Instant::now();
2468 let free_for_decode = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
2469 let prefer_tiled = Self::should_proactively_tile_vae_decode(
2470 req.width,
2471 req.height,
2472 vae_device.is_cuda(),
2473 free_for_decode,
2474 );
2475
2476 let image = Self::decode_vae_with_fallback(
2477 &latents,
2478 &vae,
2479 &vae_device,
2480 &device,
2481 &self.base.progress,
2482 prefer_tiled,
2483 || self.load_vae(&Device::Cpu, DType::F32),
2484 )?;
2485 Self::validate_qwen_tensor_boundary("image_pre_postprocess", &image)?;
2486 Self::debug_tensor_stats("image_pre_postprocess", &image);
2487 let image = postprocess_image(&image)?;
2488 let post_stats = Self::validate_qwen_tensor_boundary("image_postprocess", &image)?;
2489 Self::debug_tensor_stats("image_postprocess", &image);
2490 let image = image.i(0)?;
2491 if Self::near_black_image_stats(post_stats) {
2492 self.base.progress.info(
2493 "Qwen diagnostic: decoded image is near-black after VAE postprocess; inspect MOLD_QWEN_DEBUG tensor stats to separate denoise math from VAE decode",
2494 );
2495 tracing::warn!(
2496 min = post_stats.min,
2497 max = post_stats.max,
2498 mean = post_stats.mean,
2499 "Qwen decoded image is near-black after VAE postprocess"
2500 );
2501 }
2502
2503 self.base
2504 .progress
2505 .stage_done("VAE decode", vae_decode_start.elapsed());
2506
2507 let output_metadata = build_output_metadata(req, seed, None);
2508 let image_bytes = encode_image(
2509 &image,
2510 req.resolved_output_format(),
2511 req.width,
2512 req.height,
2513 output_metadata.as_ref(),
2514 )?;
2515
2516 let generation_time_ms = start.elapsed().as_millis() as u64;
2517 tracing::info!(
2518 generation_time_ms,
2519 seed,
2520 "sequential Qwen-Image generation complete"
2521 );
2522
2523 Ok(GenerateResponse {
2524 images: vec![ImageData {
2525 data: image_bytes,
2526 format: req.resolved_output_format(),
2527 width: req.width,
2528 height: req.height,
2529 index: 0,
2530 }],
2531 generation_time_ms,
2532 model: req.model.clone(),
2533 seed_used: seed,
2534 video: None,
2535 gpu: None,
2536 })
2537 }
2538
2539 fn generate_edit_loaded(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
2540 let progress = &self.base.progress;
2541 let start = Instant::now();
2542
2543 let loaded_ref = self
2544 .base
2545 .loaded
2546 .as_ref()
2547 .ok_or_else(|| anyhow::anyhow!("model not loaded"))?;
2548 let needs_reload = loaded_ref.transformer.is_none();
2549 if needs_reload {
2550 let mut loaded_mut = self
2551 .base
2552 .loaded
2553 .take()
2554 .ok_or_else(|| anyhow::anyhow!("model not loaded"))?;
2555 progress.stage_start("Reloading Qwen-Image transformer");
2556 let reload_start = Instant::now();
2557 self.reload_transformer(&mut loaded_mut, req.width as usize, req.height as usize)?;
2558 progress.stage_done("Reloading Qwen-Image transformer", reload_start.elapsed());
2559 self.base.loaded = Some(loaded_mut);
2560 }
2561
2562 let is_edit_family = self.is_edit_family();
2563 let loaded = self
2564 .base
2565 .loaded
2566 .as_mut()
2567 .ok_or_else(|| anyhow::anyhow!("model not loaded"))?;
2568 let seed = req.seed.unwrap_or_else(rand_seed);
2569 let width = req.width as usize;
2570 let height = req.height as usize;
2571 let edit_images = req
2572 .edit_images
2573 .as_ref()
2574 .ok_or_else(|| anyhow::anyhow!("qwen-image-edit requires edit_images"))?;
2575 let use_cfg = req.guidance > 1.0;
2576 let negative_prompt = req
2577 .negative_prompt
2578 .as_deref()
2579 .unwrap_or(QWEN_EMPTY_NEGATIVE_PROMPT);
2580 let formatted_prompt = Self::qwen_image_edit_prompt(&req.prompt, edit_images.len());
2581 let formatted_negative = Self::qwen_image_edit_prompt(negative_prompt, edit_images.len());
2582
2583 tracing::info!(
2584 prompt = %req.prompt,
2585 seed,
2586 width,
2587 height,
2588 steps = req.steps,
2589 edit_images = edit_images.len(),
2590 "starting Qwen-Image edit generation"
2591 );
2592
2593 if loaded.text_encoder.model.is_none() {
2594 let label = if loaded.text_encoder.is_parked() {
2595 "Unparking Qwen2.5 encoder (CPU→GPU)"
2596 } else {
2597 "Reloading Qwen2.5 encoder"
2598 };
2599 progress.stage_start(label);
2600 let reload_start = Instant::now();
2601 if loaded.text_encoder.is_parked() {
2602 loaded.text_encoder.unpark_to_gpu(progress)?;
2603 } else {
2604 loaded.text_encoder.reload(progress)?;
2605 }
2606 progress.stage_done(label, reload_start.elapsed());
2607 }
2608
2609 progress.stage_start("Encoding prompt (Qwen2.5 edit)");
2610 let encode_start = Instant::now();
2611 let (encoder_hidden_states, encoder_attention_mask, _) =
2612 loaded.text_encoder.encode_formatted_multimodal(
2613 &formatted_prompt,
2614 edit_images,
2615 &loaded.device,
2616 loaded.dtype,
2617 )?;
2618 progress.stage_done("Encoding prompt (Qwen2.5 edit)", encode_start.elapsed());
2619 let (encoder_hidden_states, encoder_attention_mask, uncond_hs, uncond_mask) = if use_cfg {
2620 progress.stage_start("Encoding negative prompt (Qwen2.5 edit)");
2621 let neg_start = Instant::now();
2622 let (hs, mask, _) = loaded.text_encoder.encode_formatted_multimodal(
2623 &formatted_negative,
2624 edit_images,
2625 &loaded.device,
2626 loaded.dtype,
2627 )?;
2628 progress.stage_done(
2629 "Encoding negative prompt (Qwen2.5 edit)",
2630 neg_start.elapsed(),
2631 );
2632 let ((cond_hs, cond_mask), (neg_hs, neg_mask)) = align_cfg_conditioning(
2633 &encoder_hidden_states,
2634 &encoder_attention_mask,
2635 &hs,
2636 &mask,
2637 )?;
2638 (cond_hs, cond_mask, Some(neg_hs), Some(neg_mask))
2639 } else {
2640 (encoder_hidden_states, encoder_attention_mask, None, None)
2641 };
2642
2643 let drop_text_encoder = is_edit_family || loaded.text_encoder.on_gpu;
2644 if drop_text_encoder {
2645 let park_mode = crate::device::keep_te_in_ram()
2646 && !loaded.device.is_metal()
2647 && !loaded.text_encoder.is_quantized;
2648 if park_mode {
2649 loaded.text_encoder.park_to_cpu()?;
2650 tracing::info!(
2651 on_gpu = loaded.text_encoder.on_gpu,
2652 "Qwen2.5 text encoder parked to CPU host RAM after edit conditioning"
2653 );
2654 } else {
2655 loaded.text_encoder.drop_weights();
2656 tracing::info!(
2657 on_gpu = loaded.text_encoder.on_gpu,
2658 "Qwen2.5 text encoder dropped after edit conditioning"
2659 );
2660 }
2661 }
2662
2663 let mut packed_input_storage = Vec::with_capacity(edit_images.len());
2664 let mut img_shapes = vec![(1usize, height / 16, width / 16)];
2665 progress.stage_start("Encoding edit images (VAE)");
2666 let encode_start = Instant::now();
2667 for image_bytes in edit_images {
2668 let (vae_width, vae_height) =
2669 Self::qwen_image_edit_image_dims(image_bytes, QWEN_IMAGE_EDIT_VAE_AREA)?;
2670 let encoded = Self::encode_vae_with_fallback(
2671 image_bytes,
2672 vae_width,
2673 vae_height,
2674 &loaded.vae,
2675 &loaded.vae_device,
2676 &loaded.device,
2677 progress,
2678 || {
2679 Ok(QwenImageVae::load(
2680 &loaded.vae_path,
2681 &Device::Cpu,
2682 DType::F32,
2683 progress,
2684 )?)
2685 },
2686 )?
2687 .to_device(&loaded.device)?
2688 .to_dtype(loaded.dtype)?;
2689 img_shapes.push((1, encoded.dim(2)? / 2, encoded.dim(3)? / 2));
2690 packed_input_storage.push(Self::pack_latents_4d(&encoded)?);
2691 }
2692 progress.stage_done("Encoding edit images (VAE)", encode_start.elapsed());
2693
2694 let packed_inputs = if packed_input_storage.is_empty() {
2695 None
2696 } else {
2697 let tensors = packed_input_storage.iter().collect::<Vec<_>>();
2698 Some(Tensor::cat(&tensors, 1)?)
2699 };
2700
2701 let noise = crate::engine::seeded_randn(
2702 seed,
2703 &[1, 16, height / 8, width / 8],
2704 &loaded.device,
2705 loaded.dtype,
2706 )?;
2707 let mut scheduler =
2708 QwenImageScheduler::new(req.steps as usize, (height / 16) * (width / 16));
2709 let num_steps = scheduler.num_steps();
2710 let mut latents = Self::pack_latents_4d(&(noise * scheduler.initial_sigma())?)?;
2711 let output_seq_len = latents.dim(1)?;
2712
2713 let denoise_label = format!("Denoising edit ({} steps)", num_steps);
2714 progress.stage_start(&denoise_label);
2715 let denoise_start = Instant::now();
2716
2717 {
2718 let transformer = loaded
2719 .transformer
2720 .as_ref()
2721 .expect("transformer must be loaded for denoising");
2722 let use_batched_cfg = use_cfg && transformer.supports_cfg_batching();
2723 let (batched_hs, batched_mask) = if use_batched_cfg {
2724 let hs = Tensor::cat(&[&encoder_hidden_states, uncond_hs.as_ref().unwrap()], 0)?;
2725 let mask =
2726 Tensor::cat(&[&encoder_attention_mask, uncond_mask.as_ref().unwrap()], 0)?;
2727 (hs, mask)
2728 } else {
2729 (
2730 encoder_hidden_states.clone(),
2731 encoder_attention_mask.clone(),
2732 )
2733 };
2734
2735 for step in 0..num_steps {
2736 let step_start = Instant::now();
2737 let t = scheduler.current_timestep();
2738 let timestep = if use_batched_cfg {
2739 Tensor::from_vec(vec![t as f32; 2], (2,), &loaded.device)?
2740 .to_dtype(loaded.dtype)?
2741 } else {
2742 Tensor::from_vec(vec![t as f32], (1,), &loaded.device)?
2743 .to_dtype(loaded.dtype)?
2744 };
2745
2746 let latent_model_input = if let Some(ref packed_inputs) = packed_inputs {
2747 Tensor::cat(&[&latents, packed_inputs], 1)?
2748 } else {
2749 latents.clone()
2750 };
2751
2752 let noise_pred = if use_cfg {
2753 let (cond_pred, uncond_pred) = if use_batched_cfg {
2754 let batched_input =
2755 Tensor::cat(&[&latent_model_input, &latent_model_input], 0)?;
2756 let pred = transformer.forward_packed(
2757 &batched_input,
2758 ×tep,
2759 &batched_hs,
2760 &batched_mask,
2761 &img_shapes,
2762 )?;
2763 (
2764 pred.narrow(0, 0, 1)?.narrow(1, 0, output_seq_len)?,
2765 pred.narrow(0, 1, 1)?.narrow(1, 0, output_seq_len)?,
2766 )
2767 } else {
2768 (
2769 transformer
2770 .forward_packed(
2771 &latent_model_input,
2772 ×tep,
2773 &encoder_hidden_states,
2774 &encoder_attention_mask,
2775 &img_shapes,
2776 )?
2777 .narrow(1, 0, output_seq_len)?,
2778 transformer
2779 .forward_packed(
2780 &latent_model_input,
2781 ×tep,
2782 uncond_hs.as_ref().unwrap(),
2783 uncond_mask.as_ref().unwrap(),
2784 &img_shapes,
2785 )?
2786 .narrow(1, 0, output_seq_len)?,
2787 )
2788 };
2789
2790 let cond_f32 = cond_pred.to_dtype(DType::F32)?;
2791 let uncond_f32 = uncond_pred.to_dtype(DType::F32)?;
2792 let comb = (&uncond_f32 + ((&cond_f32 - &uncond_f32)? * req.guidance)?)?;
2793 let cond_norm = cond_f32.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?;
2794 let comb_norm = comb
2795 .sqr()?
2796 .sum_keepdim(D::Minus1)?
2797 .sqrt()?
2798 .clamp(1e-8, f64::MAX)?;
2799 comb.broadcast_mul(&(cond_norm / comb_norm)?)?
2800 .to_dtype(loaded.dtype)?
2801 } else {
2802 transformer
2803 .forward_packed(
2804 &latent_model_input,
2805 ×tep,
2806 &encoder_hidden_states,
2807 &encoder_attention_mask,
2808 &img_shapes,
2809 )?
2810 .narrow(1, 0, output_seq_len)?
2811 };
2812
2813 latents = scheduler.step(&noise_pred, &latents)?;
2814 progress.emit(ProgressEvent::DenoiseStep {
2815 step: step + 1,
2816 total: num_steps,
2817 elapsed: step_start.elapsed(),
2818 });
2819 }
2820 }
2821
2822 progress.stage_done(&denoise_label, denoise_start.elapsed());
2823
2824 let latents = Self::unpack_latents_packed(&latents, height / 8, width / 8)?;
2825 let free_for_decode = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
2826 let prefer_tiled = Self::should_proactively_tile_vae_decode(
2827 req.width,
2828 req.height,
2829 loaded.vae_device.is_cuda(),
2830 free_for_decode,
2831 );
2832 let image = Self::decode_vae_with_fallback(
2833 &latents,
2834 &loaded.vae,
2835 &loaded.vae_device,
2836 &loaded.device,
2837 progress,
2838 prefer_tiled,
2839 || {
2840 Ok(QwenImageVae::load(
2841 &loaded.vae_path,
2842 &Device::Cpu,
2843 DType::F32,
2844 progress,
2845 )?)
2846 },
2847 )?;
2848 let image = postprocess_image(&image)?.i(0)?;
2849 let output_metadata = build_output_metadata(req, seed, None);
2850 let image_bytes = encode_image(
2851 &image,
2852 req.resolved_output_format(),
2853 req.width,
2854 req.height,
2855 output_metadata.as_ref(),
2856 )?;
2857
2858 Ok(GenerateResponse {
2859 images: vec![ImageData {
2860 data: image_bytes,
2861 format: req.resolved_output_format(),
2862 width: req.width,
2863 height: req.height,
2864 index: 0,
2865 }],
2866 generation_time_ms: start.elapsed().as_millis() as u64,
2867 model: req.model.clone(),
2868 seed_used: seed,
2869 video: None,
2870 gpu: None,
2871 })
2872 }
2873}
2874
2875impl QwenImageEngine {
2876 fn generate_inner(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
2877 if req.scheduler.is_some() {
2878 tracing::warn!(
2879 "scheduler selection not supported for Qwen-Image (flow-matching), ignoring"
2880 );
2881 }
2882
2883 if self.is_edit_family() {
2884 let sequential = self.base.load_strategy == LoadStrategy::Sequential;
2885 if sequential && self.base.loaded.is_none() {
2886 let original = self.base.load_strategy;
2887 self.base.load_strategy = LoadStrategy::Eager;
2888 let load_result = self.load();
2889 self.base.load_strategy = original;
2890 load_result?;
2891 }
2892 if self.base.loaded.is_none() {
2893 bail!("model not loaded -- call load() first");
2894 }
2895 let result = self.generate_edit_loaded(req);
2896 if sequential {
2897 self.unload();
2898 }
2899 return result;
2900 }
2901
2902 if self.base.load_strategy == LoadStrategy::Sequential {
2904 return self.generate_sequential(req);
2905 }
2906
2907 if self.base.loaded.is_none() {
2909 bail!("model not loaded -- call load() first");
2910 }
2911
2912 let progress = &self.base.progress;
2913 let gpu_ordinal = self.base.gpu_ordinal;
2914 let start = Instant::now();
2915
2916 let loaded_ref = self
2918 .base
2919 .loaded
2920 .as_ref()
2921 .ok_or_else(|| anyhow::anyhow!("model not loaded"))?;
2922 let needs_reload = loaded_ref.transformer.is_none();
2923 if needs_reload {
2924 let mut loaded_mut = self
2925 .base
2926 .loaded
2927 .take()
2928 .ok_or_else(|| anyhow::anyhow!("model not loaded"))?;
2929 progress.stage_start("Reloading Qwen-Image transformer");
2930 let reload_start = Instant::now();
2931 self.reload_transformer(&mut loaded_mut, req.width as usize, req.height as usize)?;
2932 progress.stage_done("Reloading Qwen-Image transformer", reload_start.elapsed());
2933 self.base.loaded = Some(loaded_mut);
2934 }
2935
2936 let loaded = self
2937 .base
2938 .loaded
2939 .as_mut()
2940 .ok_or_else(|| anyhow::anyhow!("model not loaded"))?;
2941 let seed = req.seed.unwrap_or_else(rand_seed);
2942
2943 let width = req.width as usize;
2944 let height = req.height as usize;
2945
2946 tracing::info!(
2947 prompt = %req.prompt,
2948 seed, width, height,
2949 steps = req.steps,
2950 "starting Qwen-Image generation"
2951 );
2952
2953 let use_cfg = req.guidance > 1.0;
2954 let prompt_key = prompt_text_key(&req.prompt);
2955 let uncond_key = prompt_text_key(QWEN_EMPTY_NEGATIVE_PROMPT);
2956 let prompt_cached = self
2957 .prompt_cache
2958 .lock()
2959 .expect("cache poisoned")
2960 .get_cloned(&prompt_key);
2961 let uncond_cached = if use_cfg {
2962 self.prompt_cache
2963 .lock()
2964 .expect("cache poisoned")
2965 .get_cloned(&uncond_key)
2966 } else {
2967 None
2968 };
2969 let both_cached = prompt_cached.is_some() && (!use_cfg || uncond_cached.is_some());
2970
2971 let (encoder_hidden_states, encoder_attention_mask, uncond_hs, uncond_mask) = if both_cached
2972 {
2973 let cached = prompt_cached.expect("prompt cache unexpectedly missing");
2974 progress.cache_hit("prompt conditioning");
2975 let (hs, mask) = cached.restore(&loaded.device, loaded.dtype)?;
2976 let (u_hs, u_mask) = if use_cfg {
2977 progress.cache_hit("unconditional conditioning");
2978 let ucached =
2979 uncond_cached.expect("unconditional prompt cache unexpectedly missing");
2980 let (u_hs, u_mask) = ucached.restore(&loaded.device, loaded.dtype)?;
2981 (Some(u_hs), Some(u_mask))
2982 } else {
2983 (None, None)
2984 };
2985 (hs, mask, u_hs, u_mask)
2986 } else {
2987 if loaded.text_encoder.model.is_none() {
2988 let label = if loaded.text_encoder.is_parked() {
2989 "Unparking Qwen2.5 encoder (CPU→GPU)"
2990 } else {
2991 "Reloading Qwen2.5 encoder"
2992 };
2993 progress.stage_start(label);
2994 let reload_start = Instant::now();
2995 if loaded.text_encoder.is_parked() {
2996 loaded.text_encoder.unpark_to_gpu(progress)?;
2997 } else {
2998 loaded.text_encoder.reload(progress)?;
2999 }
3000 progress.stage_done(label, reload_start.elapsed());
3001 }
3002
3003 let (hs, mask) = Self::encode_prompt_cached(
3004 progress,
3005 &self.prompt_cache,
3006 &mut loaded.text_encoder,
3007 &req.prompt,
3008 &loaded.device,
3009 loaded.dtype,
3010 )?;
3011
3012 let (u_hs, u_mask) = if use_cfg {
3013 let (hs, mask) = Self::encode_prompt_cached(
3014 progress,
3015 &self.prompt_cache,
3016 &mut loaded.text_encoder,
3017 QWEN_EMPTY_NEGATIVE_PROMPT,
3018 &loaded.device,
3019 loaded.dtype,
3020 )?;
3021 (Some(hs), Some(mask))
3022 } else {
3023 (None, None)
3024 };
3025
3026 (hs, mask, u_hs, u_mask)
3027 };
3028
3029 let (encoder_hidden_states, encoder_attention_mask, uncond_hs, uncond_mask) = if use_cfg {
3030 let ((cond_hs, cond_mask), (neg_hs, neg_mask)) = align_cfg_conditioning(
3031 &encoder_hidden_states,
3032 &encoder_attention_mask,
3033 uncond_hs.as_ref().expect("unconditional prompt missing"),
3034 uncond_mask.as_ref().expect("unconditional mask missing"),
3035 )?;
3036 (cond_hs, cond_mask, Some(neg_hs), Some(neg_mask))
3037 } else {
3038 (
3039 encoder_hidden_states,
3040 encoder_attention_mask,
3041 uncond_hs,
3042 uncond_mask,
3043 )
3044 };
3045
3046 if loaded.text_encoder.on_gpu {
3048 let free_after_encode = usable_free_vram_bytes(gpu_ordinal).unwrap_or(0);
3049 let required_for_residency = Self::qwen2_hot_text_encoder_required_vram(
3050 req.width,
3051 req.height,
3052 if req.guidance > 1.0 { 2 } else { 1 },
3053 loaded.dtype,
3054 );
3055 let action =
3056 Self::qwen2_text_encoder_post_encode_action(Qwen2TextEncoderResidencyInput {
3057 on_gpu: loaded.text_encoder.on_gpu,
3058 is_quantized: loaded.text_encoder.is_quantized,
3059 is_metal: loaded.device.is_metal(),
3060 keep_te_ram: crate::device::keep_te_in_ram(),
3061 prompt_cache_miss: !both_cached,
3062 transformer_resident: loaded.transformer.is_some(),
3063 free_vram_bytes: free_after_encode,
3064 required_vram_bytes: required_for_residency,
3065 });
3066 match action {
3067 Qwen2TextEncoderPostEncodeAction::KeepGpu => {
3068 progress.info(&format!(
3069 "Keeping Qwen2.5 text encoder on GPU for hot prompt-cache misses ({} free >= {} reserve)",
3070 fmt_gb(free_after_encode),
3071 fmt_gb(required_for_residency)
3072 ));
3073 tracing::info!(
3074 free_vram_bytes = free_after_encode,
3075 required_vram_bytes = required_for_residency,
3076 is_quantized = loaded.text_encoder.is_quantized,
3077 "Qwen2.5 text encoder kept on GPU after cache miss"
3078 );
3079 }
3080 Qwen2TextEncoderPostEncodeAction::ParkCpu => {
3081 loaded.text_encoder.park_to_cpu()?;
3082 progress.info(&format!(
3083 "Parked Qwen2.5 text encoder to CPU host RAM before denoise ({} free < {} reserve)",
3084 fmt_gb(free_after_encode),
3085 fmt_gb(required_for_residency)
3086 ));
3087 tracing::info!("Qwen2.5 text encoder parked to CPU host RAM");
3088 }
3089 Qwen2TextEncoderPostEncodeAction::Drop => {
3090 loaded.text_encoder.drop_weights();
3091 progress.info(&format!(
3092 "Dropped Qwen2.5 text encoder before denoise ({} free < {} reserve or cache hit)",
3093 fmt_gb(free_after_encode),
3094 fmt_gb(required_for_residency)
3095 ));
3096 tracing::info!("Qwen2.5 text encoder dropped from GPU");
3097 }
3098 }
3099 }
3100
3101 let vae_downsample = 8;
3103 let latent_h = height / vae_downsample;
3104 let latent_w = width / vae_downsample;
3105 let is_img2img = req.source_image.is_some();
3106
3107 let (prepared_img2img_latents, inpaint_ctx) =
3109 if let Some(ref source_bytes) = req.source_image {
3110 let encoded = Self::encode_vae_with_fallback(
3111 source_bytes,
3112 req.width,
3113 req.height,
3114 &loaded.vae,
3115 &loaded.vae_device,
3116 &loaded.device,
3117 progress,
3118 || {
3119 Ok(QwenImageVae::load(
3120 &loaded.vae_path,
3121 &Device::Cpu,
3122 DType::F32,
3123 progress,
3124 )?)
3125 },
3126 )?;
3127 let encoded = encoded.to_device(&loaded.device)?.to_dtype(loaded.dtype)?;
3128 let start_sigma = QwenImageScheduler::new_img2img(
3129 req.steps as usize,
3130 image_seq_len(latent_h, latent_w, loaded.transformer_cfg.patch_size),
3131 req.strength,
3132 )
3133 .0
3134 .initial_sigma();
3135 let prepared = crate::img2img::prepare_flow_match_img2img(
3136 &encoded,
3137 seed,
3138 &[1, 16, latent_h, latent_w],
3139 start_sigma,
3140 req.mask_image.as_deref(),
3141 latent_h,
3142 latent_w,
3143 &loaded.device,
3144 loaded.dtype,
3145 )?;
3146
3147 (Some(prepared.initial_latents), prepared.inpaint_ctx)
3148 } else {
3149 (None, None)
3150 };
3151
3152 let image_seq_len = image_seq_len(latent_h, latent_w, loaded.transformer_cfg.patch_size);
3154 let (mut scheduler, num_steps) = if is_img2img {
3155 QwenImageScheduler::new_img2img(req.steps as usize, image_seq_len, req.strength)
3156 } else {
3157 let sched = QwenImageScheduler::new(req.steps as usize, image_seq_len);
3158 let n = sched.num_steps();
3159 (sched, n)
3160 };
3161
3162 let mut latents = if let Some(initial) = &prepared_img2img_latents {
3164 initial.clone()
3165 } else {
3166 let noise = crate::engine::seeded_randn(
3167 seed,
3168 &[1, 16, latent_h, latent_w],
3169 &loaded.device,
3170 loaded.dtype,
3171 )?;
3172 (noise * scheduler.initial_sigma())?
3173 };
3174
3175 let denoise_label = format!("Denoising ({} steps)", num_steps);
3177 progress.stage_start(&denoise_label);
3178 let denoise_start = Instant::now();
3179
3180 {
3181 let transformer = loaded
3182 .transformer
3183 .as_ref()
3184 .expect("transformer must be loaded for denoising");
3185
3186 let use_batched_cfg = use_cfg && transformer.supports_cfg_batching();
3187 if use_cfg && !use_batched_cfg {
3188 progress.info(
3189 "Low-memory quantized Qwen CUDA path detected — disabling CFG batching to reduce peak CUDA memory",
3190 );
3191 }
3192
3193 let (batched_hs, batched_mask) = if use_batched_cfg {
3196 let hs = Tensor::cat(&[&encoder_hidden_states, uncond_hs.as_ref().unwrap()], 0)?;
3197 let mask =
3198 Tensor::cat(&[&encoder_attention_mask, uncond_mask.as_ref().unwrap()], 0)?;
3199 (hs, mask)
3200 } else {
3201 (
3202 encoder_hidden_states.clone(),
3203 encoder_attention_mask.clone(),
3204 )
3205 };
3206
3207 for step in 0..num_steps {
3208 let step_start = Instant::now();
3209 let t = scheduler.current_timestep();
3210 let noise_pred = if use_cfg {
3211 let (cond_pred, uncond_pred) = if use_batched_cfg {
3212 let t_tensor = Tensor::from_vec(vec![t as f32; 2], (2,), &loaded.device)?
3213 .to_dtype(loaded.dtype)?;
3214 let batched_latents = Tensor::cat(&[&latents, &latents], 0)?;
3215 let batched_pred = transformer.forward(
3216 &batched_latents,
3217 &t_tensor,
3218 &batched_hs,
3219 &batched_mask,
3220 )?;
3221 (batched_pred.narrow(0, 0, 1)?, batched_pred.narrow(0, 1, 1)?)
3222 } else {
3223 let t_tensor = Tensor::from_vec(vec![t as f32], (1,), &loaded.device)?
3224 .to_dtype(loaded.dtype)?;
3225 (
3226 transformer.forward(
3227 &latents,
3228 &t_tensor,
3229 &encoder_hidden_states,
3230 &encoder_attention_mask,
3231 )?,
3232 transformer.forward(
3233 &latents,
3234 &t_tensor,
3235 uncond_hs.as_ref().unwrap(),
3236 uncond_mask.as_ref().unwrap(),
3237 )?,
3238 )
3239 };
3240 let cond_f32 = cond_pred.to_dtype(DType::F32)?;
3242 let uncond_f32 = uncond_pred.to_dtype(DType::F32)?;
3243 let comb = (&uncond_f32 + ((&cond_f32 - &uncond_f32)? * req.guidance)?)?;
3244 let cond_norm = cond_f32.sqr()?.sum_keepdim(1)?.sqrt()?;
3245 let comb_norm = comb.sqr()?.sum_keepdim(1)?.sqrt()?.clamp(1e-8, f64::MAX)?;
3246 let rescaled = comb.broadcast_mul(&(cond_norm / comb_norm)?)?;
3247 rescaled.to_dtype(loaded.dtype)?
3248 } else {
3249 let t_tensor = Tensor::from_vec(vec![t as f32], (1,), &loaded.device)?
3250 .to_dtype(loaded.dtype)?;
3251 transformer.forward(
3252 &latents,
3253 &t_tensor,
3254 &encoder_hidden_states,
3255 &encoder_attention_mask,
3256 )?
3257 };
3258 if step == 0 || step == num_steps / 2 || step == num_steps - 1 {
3259 Self::debug_tensor_stats(&format!("noise_pred[{step}]"), &noise_pred);
3260 Self::debug_tensor_stats(&format!("latents[{step}]"), &latents);
3261 }
3262 if step == 0 {
3263 Self::validate_qwen_tensor_boundary("noise_pred[0]", &noise_pred)?;
3264 }
3265 latents = scheduler.step(&noise_pred, &latents)?;
3266 if step == num_steps - 1 {
3267 Self::validate_qwen_tensor_boundary("latents_final", &latents)?;
3268 }
3269
3270 if let Some(ref ctx) = inpaint_ctx {
3272 latents = crate::img2img::apply_flow_match_inpaint(
3273 &latents,
3274 ctx,
3275 scheduler.sigmas[step + 1],
3276 )?;
3277 }
3278
3279 progress.emit(ProgressEvent::DenoiseStep {
3280 step: step + 1,
3281 total: num_steps,
3282 elapsed: step_start.elapsed(),
3283 });
3284 }
3285 }
3286
3287 progress.stage_done(&denoise_label, denoise_start.elapsed());
3288
3289 drop(encoder_hidden_states);
3291 drop(encoder_attention_mask);
3292 drop(uncond_hs);
3293 drop(uncond_mask);
3294
3295 progress.stage_start("VAE decode");
3297 let vae_start = Instant::now();
3298 let free_for_decode = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
3299 let prefer_tiled = Self::should_proactively_tile_vae_decode(
3300 req.width,
3301 req.height,
3302 loaded.vae_device.is_cuda(),
3303 free_for_decode,
3304 );
3305
3306 let keep_transformer_hot = Self::can_keep_transformer_hot_for_vae(loaded);
3308 let image = if keep_transformer_hot {
3309 match Self::decode_vae_gpu_only(
3310 &latents,
3311 &loaded.vae,
3312 &loaded.vae_device,
3313 &loaded.device,
3314 progress,
3315 prefer_tiled,
3316 ) {
3317 Ok(image) => {
3318 progress.info(
3319 "Kept quantized Qwen transformer resident across VAE decode for faster hot-path reuse",
3320 );
3321 image
3322 }
3323 Err(err) if Self::is_oom_error(&err) => {
3324 loaded.transformer = None;
3325 loaded.device.synchronize()?;
3326 progress.info(
3327 "Dropping Qwen-Image transformer after resident VAE decode OOM and retrying",
3328 );
3329 Self::decode_vae_with_fallback(
3330 &latents,
3331 &loaded.vae,
3332 &loaded.vae_device,
3333 &loaded.device,
3334 progress,
3335 prefer_tiled,
3336 || {
3337 QwenImageVae::load(&loaded.vae_path, &Device::Cpu, DType::F32, progress)
3338 .map_err(Into::into)
3339 },
3340 )?
3341 }
3342 Err(err) => return Err(err),
3343 }
3344 } else {
3345 loaded.transformer = None;
3346 loaded.device.synchronize()?;
3347 tracing::info!("Qwen-Image transformer dropped to free VRAM for VAE decode");
3348 Self::decode_vae_with_fallback(
3349 &latents,
3350 &loaded.vae,
3351 &loaded.vae_device,
3352 &loaded.device,
3353 progress,
3354 prefer_tiled,
3355 || {
3356 QwenImageVae::load(&loaded.vae_path, &Device::Cpu, DType::F32, progress)
3357 .map_err(Into::into)
3358 },
3359 )?
3360 };
3361 Self::validate_qwen_tensor_boundary("image_pre_postprocess", &image)?;
3362 Self::debug_tensor_stats("image_pre_postprocess", &image);
3363 let image = postprocess_image(&image)?;
3364 let post_stats = Self::validate_qwen_tensor_boundary("image_postprocess", &image)?;
3365 Self::debug_tensor_stats("image_postprocess", &image);
3366 let image = image.i(0)?;
3367 if Self::near_black_image_stats(post_stats) {
3368 progress.info(
3369 "Qwen diagnostic: decoded image is near-black after VAE postprocess; inspect MOLD_QWEN_DEBUG tensor stats to separate denoise math from VAE decode",
3370 );
3371 tracing::warn!(
3372 min = post_stats.min,
3373 max = post_stats.max,
3374 mean = post_stats.mean,
3375 "Qwen decoded image is near-black after VAE postprocess"
3376 );
3377 }
3378
3379 progress.stage_done("VAE decode", vae_start.elapsed());
3380
3381 let output_metadata = build_output_metadata(req, seed, None);
3383 let image_bytes = encode_image(
3384 &image,
3385 req.resolved_output_format(),
3386 req.width,
3387 req.height,
3388 output_metadata.as_ref(),
3389 )?;
3390
3391 let generation_time_ms = start.elapsed().as_millis() as u64;
3392 tracing::info!(generation_time_ms, seed, "Qwen-Image generation complete");
3393
3394 Ok(GenerateResponse {
3395 images: vec![ImageData {
3396 data: image_bytes,
3397 format: req.resolved_output_format(),
3398 width: req.width,
3399 height: req.height,
3400 index: 0,
3401 }],
3402 generation_time_ms,
3403 model: req.model.clone(),
3404 seed_used: seed,
3405 video: None,
3406 gpu: None,
3407 })
3408 }
3409}
3410
3411impl InferenceEngine for QwenImageEngine {
3412 fn generate(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
3413 self.pending_placement = req.placement.clone();
3414 self.pending_loras = effective_loras(req);
3415 let result = self.generate_inner(req);
3416 self.pending_placement = None;
3417 self.pending_loras.clear();
3418 result
3419 }
3420
3421 fn model_name(&self) -> &str {
3422 self.base.model_name()
3423 }
3424
3425 fn is_loaded(&self) -> bool {
3426 self.base.is_loaded()
3427 }
3428
3429 fn load(&mut self) -> Result<()> {
3430 QwenImageEngine::load(self)
3431 }
3432
3433 fn unload(&mut self) {
3434 self.base.unload();
3435 clear_cache(&self.prompt_cache);
3436 }
3437
3438 fn set_on_progress(&mut self, callback: ProgressCallback) {
3439 self.base.set_on_progress(callback);
3440 }
3441
3442 fn clear_on_progress(&mut self) {
3443 self.base.clear_on_progress();
3444 }
3445
3446 fn model_paths(&self) -> Option<&mold_core::ModelPaths> {
3447 Some(&self.base.paths)
3448 }
3449}
3450
3451#[cfg(test)]
3452mod tests {
3453 use super::*;
3454 use crate::engine::LoadStrategy;
3455 use crate::shared_pool::SharedPool;
3456 use candle_core::Shape;
3457 use mold_core::ModelPaths;
3458 use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
3459 use std::collections::HashMap;
3460 use std::fs;
3461 use std::path::{Path, PathBuf};
3462 use std::sync::{Arc, Mutex};
3463 use std::time::{SystemTime, UNIX_EPOCH};
3464 use tokenizers::models::bpe::BPE;
3465
3466 fn temp_test_dir(prefix: &str) -> PathBuf {
3467 let suffix = SystemTime::now()
3468 .duration_since(UNIX_EPOCH)
3469 .unwrap()
3470 .as_nanos();
3471 let dir = std::env::temp_dir().join(format!("{prefix}-{}-{suffix}", std::process::id()));
3472 fs::create_dir_all(&dir).unwrap();
3473 dir
3474 }
3475
3476 fn touch(dir: &Path, name: &str) -> PathBuf {
3477 let path = dir.join(name);
3478 fs::write(&path, b"test").unwrap();
3479 path
3480 }
3481
3482 fn png_with_dimensions(width: u32, height: u32) -> Vec<u8> {
3483 let img = image::RgbImage::from_fn(width, height, |_, _| image::Rgb([255, 0, 0]));
3484 let mut buf = std::io::Cursor::new(Vec::new());
3485 image::DynamicImage::ImageRgb8(img)
3486 .write_to(&mut buf, image::ImageFormat::Png)
3487 .unwrap();
3488 buf.into_inner()
3489 }
3490
3491 fn qwen_image_model_paths(
3492 transformer: PathBuf,
3493 transformer_shards: Vec<PathBuf>,
3494 vae: PathBuf,
3495 text_tokenizer: Option<PathBuf>,
3496 ) -> ModelPaths {
3497 ModelPaths {
3498 transformer,
3499 transformer_shards,
3500 vae,
3501 spatial_upscaler: None,
3502 temporal_upscaler: None,
3503 distilled_lora: None,
3504 t5_encoder: None,
3505 clip_encoder: None,
3506 t5_tokenizer: None,
3507 clip_tokenizer: None,
3508 clip_encoder_2: None,
3509 clip_tokenizer_2: None,
3510 text_encoder_files: vec![],
3511 text_tokenizer,
3512 decoder: None,
3513 }
3514 }
3515
3516 fn resolved_text_encoder(is_gguf: bool, auto_use_gpu: bool) -> ResolvedQwen2TextEncoder {
3517 ResolvedQwen2TextEncoder {
3518 paths: vec![],
3519 vision_paths: vec![],
3520 is_gguf,
3521 variant_label: if is_gguf {
3522 "q6".to_string()
3523 } else {
3524 "bf16".to_string()
3525 },
3526 size_bytes: 0,
3527 auto_use_gpu,
3528 }
3529 }
3530
3531 fn tensor_values_u8(t: &Tensor) -> Vec<u8> {
3532 t.flatten_all()
3533 .unwrap()
3534 .to_vec1::<u8>()
3535 .expect("u8 tensor values")
3536 }
3537
3538 fn tensor_values_f32(t: &Tensor) -> Vec<f32> {
3539 t.flatten_all()
3540 .unwrap()
3541 .to_vec1::<f32>()
3542 .expect("f32 tensor values")
3543 }
3544
3545 #[test]
3546 fn safetensors_is_fp8_uses_filename_hint() {
3547 assert!(safetensors_is_fp8(Path::new(
3548 "/tmp/qwen-image-fp8.safetensors"
3549 )));
3550 assert!(!safetensors_is_fp8(Path::new(
3551 "/tmp/qwen-image.safetensors"
3552 )));
3553 }
3554
3555 #[test]
3556 fn text_encoder_is_fp8_uses_filename_hint() {
3557 assert!(text_encoder_is_fp8(&[PathBuf::from(
3558 "/tmp/qwen2-text-encoder-fp8-00001-of-00002.safetensors"
3559 )]));
3560 assert!(!text_encoder_is_fp8(&[PathBuf::from(
3561 "/tmp/qwen2-text-encoder-00001-of-00002.safetensors"
3562 )]));
3563 }
3564
3565 #[test]
3566 fn cached_prompt_conditioning_roundtrips_and_restores_mask() {
3567 let device = Device::Cpu;
3568 let hidden_states = Tensor::from_vec(
3569 vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
3570 Shape::from((1, 3, 2)),
3571 &device,
3572 )
3573 .unwrap();
3574 let cached = CachedPromptConditioning::from_parts(&hidden_states, 2).unwrap();
3575
3576 let (restored_hs, restored_mask) = cached.restore(&device, DType::F32).unwrap();
3577
3578 assert_eq!(
3579 tensor_values_f32(&restored_hs),
3580 tensor_values_f32(&hidden_states)
3581 );
3582 assert_eq!(tensor_values_u8(&restored_mask), vec![1, 1, 0]);
3583 }
3584
3585 #[test]
3586 fn pad_text_conditioning_keeps_original_when_target_matches() {
3587 let device = Device::Cpu;
3588 let hidden_states =
3589 Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], Shape::from((1, 2, 2)), &device).unwrap();
3590 let mask = Tensor::from_vec(vec![1u8, 1], Shape::from((1, 2)), &device).unwrap();
3591
3592 let (padded_hs, padded_mask) = pad_text_conditioning(&hidden_states, &mask, 2).unwrap();
3593
3594 assert_eq!(
3595 tensor_values_f32(&padded_hs),
3596 tensor_values_f32(&hidden_states)
3597 );
3598 assert_eq!(tensor_values_u8(&padded_mask), vec![1, 1]);
3599 }
3600
3601 #[test]
3602 fn pad_text_conditioning_appends_zero_padding() {
3603 let device = Device::Cpu;
3604 let hidden_states =
3605 Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], Shape::from((1, 2, 2)), &device).unwrap();
3606 let mask = Tensor::from_vec(vec![1u8, 0], Shape::from((1, 2)), &device).unwrap();
3607
3608 let (padded_hs, padded_mask) = pad_text_conditioning(&hidden_states, &mask, 4).unwrap();
3609
3610 assert_eq!(padded_hs.dims3().unwrap(), (1, 4, 2));
3611 assert_eq!(
3612 tensor_values_f32(&padded_hs),
3613 vec![1.0, 2.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0]
3614 );
3615 assert_eq!(tensor_values_u8(&padded_mask), vec![1, 0, 0, 0]);
3616 }
3617
3618 #[test]
3619 fn pad_text_conditioning_rejects_shrinking() {
3620 let device = Device::Cpu;
3621 let hidden_states =
3622 Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], Shape::from((1, 2, 2)), &device).unwrap();
3623 let mask = Tensor::from_vec(vec![1u8, 1], Shape::from((1, 2)), &device).unwrap();
3624
3625 let err = pad_text_conditioning(&hidden_states, &mask, 1).unwrap_err();
3626 assert!(err.to_string().contains("cannot shrink text conditioning"));
3627 }
3628
3629 #[test]
3630 fn align_cfg_conditioning_pads_shorter_branch_to_match_longer_one() {
3631 let device = Device::Cpu;
3632 let cond_hs = Tensor::from_vec(
3633 vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
3634 Shape::from((1, 3, 2)),
3635 &device,
3636 )
3637 .unwrap();
3638 let cond_mask = Tensor::from_vec(vec![1u8, 1, 1], Shape::from((1, 3)), &device).unwrap();
3639 let uncond_hs = Tensor::from_vec(
3640 vec![7.0f32, 8.0, 9.0, 10.0],
3641 Shape::from((1, 2, 2)),
3642 &device,
3643 )
3644 .unwrap();
3645 let uncond_mask = Tensor::from_vec(vec![1u8, 0], Shape::from((1, 2)), &device).unwrap();
3646
3647 let ((cond_hs, cond_mask), (uncond_hs, uncond_mask)) =
3648 align_cfg_conditioning(&cond_hs, &cond_mask, &uncond_hs, &uncond_mask).unwrap();
3649
3650 assert_eq!(cond_hs.dims3().unwrap(), (1, 3, 2));
3651 assert_eq!(uncond_hs.dims3().unwrap(), (1, 3, 2));
3652 assert_eq!(tensor_values_u8(&cond_mask), vec![1, 1, 1]);
3653 assert_eq!(tensor_values_u8(&uncond_mask), vec![1, 0, 0]);
3654 assert_eq!(
3655 tensor_values_f32(&uncond_hs),
3656 vec![7.0, 8.0, 9.0, 10.0, 0.0, 0.0]
3657 );
3658 }
3659
3660 #[test]
3661 fn qwen_image_detects_gguf_transformer() {
3662 let engine = QwenImageEngine::new(
3663 "qwen-image:q4".to_string(),
3664 ModelPaths {
3665 transformer: PathBuf::from("/tmp/qwen-image-Q4_K_S.gguf"),
3666 transformer_shards: vec![],
3667 vae: PathBuf::from("/tmp/vae.safetensors"),
3668 spatial_upscaler: None,
3669 temporal_upscaler: None,
3670 distilled_lora: None,
3671 t5_encoder: None,
3672 clip_encoder: None,
3673 t5_tokenizer: None,
3674 clip_tokenizer: None,
3675 clip_encoder_2: None,
3676 clip_tokenizer_2: None,
3677 text_encoder_files: vec![],
3678 text_tokenizer: Some(PathBuf::from("/tmp/tokenizer.json")),
3679 decoder: None,
3680 },
3681 LoadStrategy::Sequential,
3682 0,
3683 false,
3684 None,
3685 );
3686
3687 assert!(engine.detect_is_quantized());
3688 }
3689
3690 #[test]
3691 fn qwen_image_text_encoder_uses_gpu_on_metal() {
3692 let plan = QwenImageEngine::qwen2_text_encoder_plan_for_mode(
3693 Qwen2TextEncoderMode::Auto,
3694 false,
3695 true,
3696 &resolved_text_encoder(true, true),
3697 );
3698 assert!(plan.use_gpu);
3699 assert!(!plan.use_cpu_staging);
3700 }
3701
3702 #[test]
3703 fn qwen_image_text_encoder_uses_gpu_on_cuda_with_headroom() {
3704 let plan = QwenImageEngine::qwen2_text_encoder_plan_for_mode(
3705 Qwen2TextEncoderMode::Auto,
3706 true,
3707 false,
3708 &resolved_text_encoder(false, true),
3709 );
3710 assert!(plan.use_gpu);
3711 assert!(!plan.use_cpu_staging);
3712 }
3713
3714 #[test]
3715 fn qwen_image_text_encoder_uses_cpu_on_cuda_without_headroom() {
3716 let plan = QwenImageEngine::qwen2_text_encoder_plan_for_mode(
3717 Qwen2TextEncoderMode::Auto,
3718 true,
3719 false,
3720 &resolved_text_encoder(false, false),
3721 );
3722 assert!(!plan.use_gpu);
3723 assert!(!plan.use_cpu_staging);
3724 }
3725
3726 #[test]
3727 fn qwen_image_cpu_safetensors_text_encoder_stays_f32() {
3728 assert_eq!(
3729 QwenImageEngine::text_encoder_load_dtype(false, DType::BF16),
3730 DType::F32
3731 );
3732 }
3733
3734 #[test]
3735 fn qwen_image_cpu_gguf_text_encoder_stays_f32() {
3736 assert_eq!(
3737 QwenImageEngine::text_encoder_load_dtype(false, DType::BF16),
3738 DType::F32
3739 );
3740 }
3741
3742 #[test]
3743 fn qwen_image_text_encoder_gpu_override_disables_metal_staging() {
3744 let plan = QwenImageEngine::qwen2_text_encoder_plan_for_mode(
3745 Qwen2TextEncoderMode::Gpu,
3746 false,
3747 true,
3748 &resolved_text_encoder(true, true),
3749 );
3750 assert!(plan.use_gpu);
3751 assert!(!plan.use_cpu_staging);
3752 }
3753
3754 #[test]
3755 fn qwen_image_auto_prefers_q6_on_metal_with_headroom() {
3756 let q6 = mold_core::manifest::find_qwen2_vl_variant("q6").unwrap();
3757 let resolved = QwenImageEngine::choose_text_encoder_source(
3758 Some("auto"),
3759 false,
3760 true,
3761 qwen2_vram_threshold(q6.size_bytes) + 1,
3762 16_600_000_000,
3763 Qwen2TextEncoderUsage::Resident,
3764 )
3765 .unwrap();
3766 assert!(resolved.is_gguf);
3767 assert_eq!(resolved.variant_label, "q6");
3768 assert!(resolved.auto_use_gpu);
3769 }
3770
3771 #[test]
3772 fn qwen_image_auto_falls_back_to_q4_on_metal_when_q6_does_not_fit() {
3773 let q4 = mold_core::manifest::find_qwen2_vl_variant("q4").unwrap();
3774 let q6 = mold_core::manifest::find_qwen2_vl_variant("q6").unwrap();
3775 let free_vram = qwen2_vram_threshold(q4.size_bytes);
3776 assert!(free_vram < qwen2_vram_threshold(q6.size_bytes));
3777
3778 let resolved = QwenImageEngine::choose_text_encoder_source(
3779 Some("auto"),
3780 false,
3781 true,
3782 free_vram,
3783 0,
3784 Qwen2TextEncoderUsage::Resident,
3785 )
3786 .unwrap();
3787 assert!(resolved.is_gguf);
3788 assert_eq!(resolved.variant_label, "q4");
3789 assert!(resolved.auto_use_gpu);
3790 }
3791
3792 #[test]
3793 fn qwen_image_auto_keeps_bf16_default_on_cuda() {
3794 let resolved = QwenImageEngine::choose_text_encoder_source(
3795 Some("auto"),
3796 true,
3797 false,
3798 QWEN2_FP16_VRAM_THRESHOLD + 1,
3799 16_600_000_000,
3800 Qwen2TextEncoderUsage::Resident,
3801 )
3802 .unwrap();
3803 assert!(!resolved.is_gguf);
3804 assert_eq!(resolved.variant_label, "bf16");
3805 assert!(resolved.auto_use_gpu);
3806 }
3807
3808 #[test]
3809 fn qwen_image_auto_prefers_quantized_gpu_on_cuda_for_resident_mode_when_it_fits() {
3810 let resolved = QwenImageEngine::choose_text_encoder_source(
3811 Some("auto"),
3812 true,
3813 false,
3814 QWEN2_FP16_VRAM_THRESHOLD - 1,
3815 16_600_000_000,
3816 Qwen2TextEncoderUsage::Resident,
3817 )
3818 .unwrap();
3819 assert!(resolved.is_gguf);
3820 assert_eq!(resolved.variant_label, "q4");
3821 assert!(resolved.auto_use_gpu);
3822 }
3823
3824 #[test]
3825 fn qwen_image_auto_uses_quantized_cpu_fallback_on_cuda_for_resident_mode() {
3826 let resolved = QwenImageEngine::choose_text_encoder_source(
3827 Some("auto"),
3828 true,
3829 false,
3830 1,
3831 16_600_000_000,
3832 Qwen2TextEncoderUsage::Resident,
3833 )
3834 .unwrap();
3835 assert!(resolved.is_gguf);
3836 assert_eq!(resolved.variant_label, "q4");
3837 assert!(!resolved.auto_use_gpu);
3838 }
3839
3840 #[test]
3841 fn qwen_image_auto_prefers_quantized_gpu_on_cuda_for_sequential_mode_when_it_fits() {
3842 let resolved = QwenImageEngine::choose_text_encoder_source(
3843 Some("auto"),
3844 true,
3845 false,
3846 QWEN2_FP16_VRAM_THRESHOLD - 1,
3847 16_600_000_000,
3848 Qwen2TextEncoderUsage::Sequential,
3849 )
3850 .unwrap();
3851 assert!(resolved.is_gguf);
3852 assert_eq!(resolved.variant_label, "q4");
3853 assert!(resolved.auto_use_gpu);
3854 }
3855
3856 #[test]
3857 fn qwen_image_auto_uses_quantized_cpu_fallback_on_cuda_for_sequential_mode() {
3858 let resolved = QwenImageEngine::choose_text_encoder_source(
3859 Some("auto"),
3860 true,
3861 false,
3862 1,
3863 16_600_000_000,
3864 Qwen2TextEncoderUsage::Sequential,
3865 )
3866 .unwrap();
3867 assert!(resolved.is_gguf);
3868 assert_eq!(resolved.variant_label, "q4");
3869 assert!(!resolved.auto_use_gpu);
3870 }
3871
3872 #[test]
3873 fn qwen_image_explicit_q6_respects_cpu_fallback_on_cuda() {
3874 let resolved = QwenImageEngine::choose_text_encoder_source(
3875 Some("q6"),
3876 true,
3877 false,
3878 1,
3879 0,
3880 Qwen2TextEncoderUsage::Resident,
3881 )
3882 .unwrap();
3883 assert!(resolved.is_gguf);
3884 assert_eq!(resolved.variant_label, "q6");
3885 assert!(!resolved.auto_use_gpu);
3886 }
3887
3888 #[test]
3889 fn qwen_image_edit_accepts_quantized_text_with_bf16_vision_sidecar() {
3890 let dir = temp_test_dir("qwen-image-edit-text-encoder");
3891 let transformer = touch(&dir, "qwen-image-edit.gguf");
3892 let vae = touch(&dir, "vae.safetensors");
3893 let tokenizer = touch(&dir, "tokenizer.json");
3894 let mut paths = qwen_image_model_paths(transformer, vec![], vae, Some(tokenizer));
3895 paths.text_encoder_files = vec![touch(&dir, "text-encoder-00001-of-00004.safetensors")];
3896 let engine = QwenImageEngine::new(
3897 "qwen-image-edit-2511:q4".to_string(),
3898 paths,
3899 LoadStrategy::Sequential,
3900 0,
3901 false,
3902 None,
3903 );
3904
3905 let resolved = engine
3906 .resolve_text_encoder_source_with_preference(
3907 &Device::Cpu,
3908 0,
3909 Qwen2TextEncoderUsage::Sequential,
3910 Some("auto"),
3911 )
3912 .unwrap();
3913 assert!(!resolved.vision_paths.is_empty());
3914
3915 let resolved = engine
3916 .resolve_text_encoder_source_with_preference(
3917 &Device::Cpu,
3918 0,
3919 Qwen2TextEncoderUsage::Sequential,
3920 Some("q4"),
3921 )
3922 .unwrap();
3923 assert!(resolved.is_gguf);
3924 assert_eq!(resolved.variant_label, "q4");
3925 assert_eq!(resolved.vision_paths.len(), 1);
3926
3927 let resolved = engine
3928 .resolve_text_encoder_source_with_preference(
3929 &Device::Cpu,
3930 0,
3931 Qwen2TextEncoderUsage::Sequential,
3932 Some("bf16"),
3933 )
3934 .unwrap();
3935 assert!(!resolved.is_gguf);
3936 assert_eq!(resolved.variant_label, "bf16");
3937 assert_eq!(resolved.vision_paths.len(), 1);
3938 }
3939
3940 #[test]
3941 fn qwen_image_edit_prompt_numbers_each_picture_placeholder() {
3942 let prompt = QwenImageEngine::qwen_image_edit_prompt("swap materials", 3);
3943 assert!(prompt.contains(QWEN_IMAGE_EDIT_SYSTEM_PROMPT));
3944 assert!(prompt.contains("Picture 1: <|vision_start|><|image_pad|><|vision_end|>"));
3945 assert!(prompt.contains("Picture 2: <|vision_start|><|image_pad|><|vision_end|>"));
3946 assert!(prompt.contains("Picture 3: <|vision_start|><|image_pad|><|vision_end|>"));
3947 assert!(prompt.ends_with("<|im_start|>assistant\n"));
3948 }
3949
3950 #[test]
3951 fn qwen_image_edit_image_dims_fit_target_area_with_16px_alignment() {
3952 let bytes = png_with_dimensions(1600, 900);
3953 let (width, height) =
3954 QwenImageEngine::qwen_image_edit_image_dims(&bytes, QWEN_IMAGE_EDIT_VAE_AREA).unwrap();
3955 assert_eq!((width, height), (1360, 768));
3956 assert_eq!(width % 16, 0);
3957 assert_eq!(height % 16, 0);
3958 }
3959
3960 #[test]
3961 fn pack_and_unpack_latents_roundtrip() {
3962 let values: Vec<f32> = (0..(16 * 4 * 6)).map(|i| i as f32).collect();
3963 let latents = Tensor::from_vec(values.clone(), (1, 16, 4, 6), &Device::Cpu).unwrap();
3964 let packed = QwenImageEngine::pack_latents_4d(&latents).unwrap();
3965 assert_eq!(packed.dims3().unwrap(), (1, 6, 64));
3966
3967 let unpacked = QwenImageEngine::unpack_latents_packed(&packed, 4, 6).unwrap();
3968 assert_eq!(
3969 unpacked.flatten_all().unwrap().to_vec1::<f32>().unwrap(),
3970 values
3971 );
3972 }
3973
3974 #[test]
3975 fn quantized_cuda_cfg_headroom_scales_with_resolution() {
3976 let native = QwenImageEngine::quantized_cuda_cfg_headroom(1328, 1328);
3977 let reduced = QwenImageEngine::quantized_cuda_cfg_headroom(512, 512);
3978 assert_eq!(native, QWEN_GGUF_NATIVE_CFG_HEADROOM);
3979 assert_eq!(reduced, QWEN_GGUF_MIN_CFG_HEADROOM);
3980 }
3981
3982 #[test]
3983 fn qwen_quantized_native_resolution_uses_split_cfg_on_24gb_cuda() {
3984 assert!(QwenImageEngine::should_split_cfg_quantized_cuda(
3985 12_300_000_000,
3986 24_600_000_000,
3987 1328,
3988 1328,
3989 ));
3990 }
3991
3992 #[test]
3993 fn qwen_quantized_reduced_resolution_keeps_batched_cfg_when_it_fits() {
3994 assert!(!QwenImageEngine::should_split_cfg_quantized_cuda(
3995 12_300_000_000,
3996 24_600_000_000,
3997 512,
3998 512,
3999 ));
4000 }
4001
4002 #[test]
4003 fn qwen_quantized_cfg_split_boundary_does_not_split_when_estimate_exactly_fits() {
4004 let headroom = QwenImageEngine::quantized_cuda_cfg_headroom(1328, 1328);
4005 let transformer_size = 12_300_000_000;
4006 let free_vram = transformer_size + headroom;
4007 assert!(!QwenImageEngine::should_split_cfg_quantized_cuda(
4008 transformer_size,
4009 free_vram,
4010 1328,
4011 1328,
4012 ));
4013 }
4014
4015 #[test]
4016 fn qwen_quantized_unknown_vram_biases_to_split_cfg() {
4017 assert!(QwenImageEngine::should_split_cfg_quantized_cuda(
4018 12_300_000_000,
4019 0,
4020 1328,
4021 1328,
4022 ));
4023 }
4024
4025 #[test]
4026 fn qwen_is_oom_error_matches_cuda_memory_allocation_string() {
4027 assert!(QwenImageEngine::is_oom_error(&"cudaErrorMemoryAllocation"));
4028 }
4029
4030 #[test]
4031 fn qwen_debug_stats_counts_nan_and_inf() {
4032 let tensor = Tensor::from_vec(
4033 vec![0.0f32, 1.0, f32::NAN, f32::INFINITY, f32::NEG_INFINITY],
4034 Shape::from((5,)),
4035 &Device::Cpu,
4036 )
4037 .unwrap();
4038
4039 let stats = QwenImageEngine::tensor_stats(&tensor).unwrap();
4040
4041 assert_eq!(stats.total, 5);
4042 assert_eq!(stats.nan_count, 1);
4043 assert_eq!(stats.pos_inf_count, 1);
4044 assert_eq!(stats.neg_inf_count, 1);
4045 assert_eq!(stats.min, 0.0);
4046 assert_eq!(stats.max, 1.0);
4047 assert_eq!(stats.mean, 0.5);
4048 }
4049
4050 #[test]
4051 fn qwen_debug_stats_detects_near_black_postprocessed_image() {
4052 let stats = QwenTensorStats {
4053 min: 0.0,
4054 max: 0.01,
4055 mean: 0.004,
4056 nan_count: 0,
4057 pos_inf_count: 0,
4058 neg_inf_count: 0,
4059 total: 1024,
4060 };
4061
4062 assert!(QwenImageEngine::near_black_image_stats(stats));
4063 }
4064
4065 #[test]
4066 fn qwen_debug_stats_does_not_flag_non_black_image() {
4067 let stats = QwenTensorStats {
4068 min: 0.0,
4069 max: 0.75,
4070 mean: 0.18,
4071 nan_count: 0,
4072 pos_inf_count: 0,
4073 neg_inf_count: 0,
4074 total: 1024,
4075 };
4076
4077 assert!(!QwenImageEngine::near_black_image_stats(stats));
4078 }
4079
4080 #[test]
4081 fn qwen_debug_stats_formats_progress_message() {
4082 let stats = QwenTensorStats {
4083 min: 0.0,
4084 max: 1.0,
4085 mean: 0.5,
4086 nan_count: 2,
4087 pos_inf_count: 1,
4088 neg_inf_count: 1,
4089 total: 10,
4090 };
4091
4092 let message = QwenImageEngine::format_tensor_stats("sample", stats);
4093
4094 assert!(message.contains("NaN=2/10"));
4095 assert!(message.contains("+Inf=1"));
4096 assert!(message.contains("-Inf=1"));
4097 }
4098
4099 #[test]
4100 fn qwen_oom_fallback_returns_primary_success_without_running_fallback() {
4101 let mut progress = ProgressReporter::default();
4102 let messages = std::sync::Arc::new(std::sync::Mutex::new(Vec::<String>::new()));
4103 let messages_clone = messages.clone();
4104 progress.set_callback(Box::new(move |event| {
4105 if let ProgressEvent::Info { message } = event {
4106 messages_clone.lock().unwrap().push(message);
4107 }
4108 }));
4109
4110 let fallback_called = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
4111 let fallback_called_clone = fallback_called.clone();
4112 let value = QwenImageEngine::with_cuda_oom_cpu_fallback(
4113 || Ok(7usize),
4114 || {
4115 fallback_called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
4116 Ok(9usize)
4117 },
4118 true,
4119 &Device::Cpu,
4120 &progress,
4121 "retrying",
4122 |_| true,
4123 )
4124 .unwrap();
4125
4126 assert_eq!(value, 7);
4127 assert!(!fallback_called.load(std::sync::atomic::Ordering::SeqCst));
4128 assert!(messages.lock().unwrap().is_empty());
4129 }
4130
4131 #[test]
4132 fn qwen_oom_fallback_retries_when_primary_ooms_on_cuda() {
4133 let mut progress = ProgressReporter::default();
4134 let messages = std::sync::Arc::new(std::sync::Mutex::new(Vec::<String>::new()));
4135 let messages_clone = messages.clone();
4136 progress.set_callback(Box::new(move |event| {
4137 if let ProgressEvent::Info { message } = event {
4138 messages_clone.lock().unwrap().push(message);
4139 }
4140 }));
4141
4142 let value = QwenImageEngine::with_cuda_oom_cpu_fallback(
4143 || Err(anyhow::anyhow!("cudaErrorMemoryAllocation")),
4144 || Ok(11usize),
4145 true,
4146 &Device::Cpu,
4147 &progress,
4148 "retrying",
4149 QwenImageEngine::is_oom_error,
4150 )
4151 .unwrap();
4152
4153 assert_eq!(value, 11);
4154 assert_eq!(messages.lock().unwrap().as_slice(), ["retrying"]);
4155 }
4156
4157 #[test]
4158 fn qwen_oom_fallback_does_not_retry_non_oom_errors() {
4159 let progress = ProgressReporter::default();
4160 let err = QwenImageEngine::with_cuda_oom_cpu_fallback(
4161 || Err(anyhow::anyhow!("not an oom")),
4162 || Ok(11usize),
4163 true,
4164 &Device::Cpu,
4165 &progress,
4166 "retrying",
4167 QwenImageEngine::is_oom_error,
4168 )
4169 .unwrap_err();
4170
4171 assert!(err.to_string().contains("not an oom"));
4172 }
4173
4174 #[test]
4175 fn qwen_tiled_fallback_returns_primary_success_without_retrying() {
4176 let progress = ProgressReporter::default();
4177 let tiled_called = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
4178 let cpu_called = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
4179 let tiled_called_clone = tiled_called.clone();
4180 let cpu_called_clone = cpu_called.clone();
4181
4182 let value = QwenImageEngine::with_cuda_tiled_then_cpu_fallback(
4183 || Ok(5usize),
4184 || {
4185 tiled_called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
4186 Ok(7usize)
4187 },
4188 || {
4189 cpu_called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
4190 Ok(9usize)
4191 },
4192 true,
4193 false,
4194 &Device::Cpu,
4195 &progress,
4196 "tiled",
4197 "cpu",
4198 |_| true,
4199 )
4200 .unwrap();
4201
4202 assert_eq!(value, 5);
4203 assert!(!tiled_called.load(std::sync::atomic::Ordering::SeqCst));
4204 assert!(!cpu_called.load(std::sync::atomic::Ordering::SeqCst));
4205 }
4206
4207 #[test]
4208 fn qwen_tiled_fallback_uses_tiled_result_before_cpu() {
4209 let mut progress = ProgressReporter::default();
4210 let messages = std::sync::Arc::new(std::sync::Mutex::new(Vec::<String>::new()));
4211 let messages_clone = messages.clone();
4212 progress.set_callback(Box::new(move |event| {
4213 if let ProgressEvent::Info { message } = event {
4214 messages_clone.lock().unwrap().push(message);
4215 }
4216 }));
4217
4218 let cpu_called = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
4219 let cpu_called_clone = cpu_called.clone();
4220 let value = QwenImageEngine::with_cuda_tiled_then_cpu_fallback(
4221 || Err(anyhow::anyhow!("out of memory")),
4222 || Ok(13usize),
4223 || {
4224 cpu_called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
4225 Ok(17usize)
4226 },
4227 true,
4228 false,
4229 &Device::Cpu,
4230 &progress,
4231 "tiled",
4232 "cpu",
4233 QwenImageEngine::is_oom_error,
4234 )
4235 .unwrap();
4236
4237 assert_eq!(value, 13);
4238 assert!(!cpu_called.load(std::sync::atomic::Ordering::SeqCst));
4239 assert_eq!(messages.lock().unwrap().as_slice(), ["tiled"]);
4240 }
4241
4242 #[test]
4243 fn qwen_tiled_fallback_uses_cpu_after_tiled_oom() {
4244 let mut progress = ProgressReporter::default();
4245 let messages = std::sync::Arc::new(std::sync::Mutex::new(Vec::<String>::new()));
4246 let messages_clone = messages.clone();
4247 progress.set_callback(Box::new(move |event| {
4248 if let ProgressEvent::Info { message } = event {
4249 messages_clone.lock().unwrap().push(message);
4250 }
4251 }));
4252
4253 let value = QwenImageEngine::with_cuda_tiled_then_cpu_fallback(
4254 || Err(anyhow::anyhow!("OUT_OF_MEMORY")),
4255 || Err(anyhow::anyhow!("OUT_OF_MEMORY")),
4256 || Ok(19usize),
4257 true,
4258 false,
4259 &Device::Cpu,
4260 &progress,
4261 "tiled",
4262 "cpu",
4263 QwenImageEngine::is_oom_error,
4264 )
4265 .unwrap();
4266
4267 assert_eq!(value, 19);
4268 assert_eq!(messages.lock().unwrap().as_slice(), ["tiled", "cpu"]);
4269 }
4270
4271 #[test]
4272 fn qwen_tiled_fallback_propagates_non_oom_tiled_error() {
4273 let progress = ProgressReporter::default();
4274 let err = QwenImageEngine::with_cuda_tiled_then_cpu_fallback(
4275 || Err(anyhow::anyhow!("out of memory")),
4276 || Err(anyhow::anyhow!("bad tiled decode")),
4277 || Ok(19usize),
4278 true,
4279 false,
4280 &Device::Cpu,
4281 &progress,
4282 "tiled",
4283 "cpu",
4284 QwenImageEngine::is_oom_error,
4285 )
4286 .unwrap_err();
4287
4288 assert!(err.to_string().contains("bad tiled decode"));
4289 }
4290
4291 #[test]
4292 fn qwen_proactive_tiled_policy_selects_native_cuda_under_pressure() {
4293 assert!(QwenImageEngine::should_proactively_tile_vae_decode(
4294 1328,
4295 1328,
4296 true,
4297 6_000_000_000
4298 ));
4299 assert!(!QwenImageEngine::should_proactively_tile_vae_decode(
4300 512,
4301 512,
4302 true,
4303 6_000_000_000
4304 ));
4305 assert!(!QwenImageEngine::should_proactively_tile_vae_decode(
4306 1328,
4307 1328,
4308 false,
4309 6_000_000_000
4310 ));
4311 assert!(!QwenImageEngine::should_proactively_tile_vae_decode(
4312 1328,
4313 1328,
4314 true,
4315 16_000_000_000
4316 ));
4317 }
4318
4319 #[test]
4320 fn qwen_proactive_tiled_decode_skips_primary_full_decode() {
4321 let mut progress = ProgressReporter::default();
4322 let messages = std::sync::Arc::new(std::sync::Mutex::new(Vec::<String>::new()));
4323 let messages_clone = messages.clone();
4324 progress.set_callback(Box::new(move |event| {
4325 if let ProgressEvent::Info { message } = event {
4326 messages_clone.lock().unwrap().push(message);
4327 }
4328 }));
4329
4330 let primary_called = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
4331 let primary_called_clone = primary_called.clone();
4332 let value = QwenImageEngine::with_cuda_tiled_then_cpu_fallback(
4333 || {
4334 primary_called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
4335 Ok(3usize)
4336 },
4337 || Ok(7usize),
4338 || Ok(9usize),
4339 true,
4340 true,
4341 &Device::Cpu,
4342 &progress,
4343 "tiled after oom",
4344 "cpu",
4345 QwenImageEngine::is_oom_error,
4346 )
4347 .unwrap();
4348
4349 assert_eq!(value, 7);
4350 assert!(!primary_called.load(std::sync::atomic::Ordering::SeqCst));
4351 assert_eq!(
4352 messages.lock().unwrap().as_slice(),
4353 ["Selecting tiled GPU VAE decode proactively"]
4354 );
4355 }
4356
4357 #[test]
4358 fn qwen_hot_text_encoder_keeps_gpu_after_cache_miss_with_headroom() {
4359 let action = QwenImageEngine::qwen2_text_encoder_post_encode_action(
4360 Qwen2TextEncoderResidencyInput {
4361 on_gpu: true,
4362 is_quantized: true,
4363 is_metal: false,
4364 keep_te_ram: false,
4365 prompt_cache_miss: true,
4366 transformer_resident: true,
4367 free_vram_bytes: 10_000_000_000,
4368 required_vram_bytes: 8_000_000_000,
4369 },
4370 );
4371
4372 assert_eq!(action, Qwen2TextEncoderPostEncodeAction::KeepGpu);
4373 }
4374
4375 #[test]
4376 fn qwen_hot_text_encoder_drops_after_cache_hit_even_with_headroom() {
4377 let action = QwenImageEngine::qwen2_text_encoder_post_encode_action(
4378 Qwen2TextEncoderResidencyInput {
4379 on_gpu: true,
4380 is_quantized: true,
4381 is_metal: false,
4382 keep_te_ram: false,
4383 prompt_cache_miss: false,
4384 transformer_resident: true,
4385 free_vram_bytes: 10_000_000_000,
4386 required_vram_bytes: 8_000_000_000,
4387 },
4388 );
4389
4390 assert_eq!(action, Qwen2TextEncoderPostEncodeAction::Drop);
4391 }
4392
4393 #[test]
4394 fn qwen_hot_text_encoder_drops_under_transformer_pressure() {
4395 let action = QwenImageEngine::qwen2_text_encoder_post_encode_action(
4396 Qwen2TextEncoderResidencyInput {
4397 on_gpu: true,
4398 is_quantized: true,
4399 is_metal: false,
4400 keep_te_ram: false,
4401 prompt_cache_miss: true,
4402 transformer_resident: true,
4403 free_vram_bytes: 7_999_999_999,
4404 required_vram_bytes: 8_000_000_000,
4405 },
4406 );
4407
4408 assert_eq!(action, Qwen2TextEncoderPostEncodeAction::Drop);
4409 }
4410
4411 #[test]
4412 fn qwen_hot_text_encoder_parks_bf16_when_keep_ram_enabled() {
4413 let action = QwenImageEngine::qwen2_text_encoder_post_encode_action(
4414 Qwen2TextEncoderResidencyInput {
4415 on_gpu: true,
4416 is_quantized: false,
4417 is_metal: false,
4418 keep_te_ram: true,
4419 prompt_cache_miss: true,
4420 transformer_resident: true,
4421 free_vram_bytes: 7_999_999_999,
4422 required_vram_bytes: 8_000_000_000,
4423 },
4424 );
4425
4426 assert_eq!(action, Qwen2TextEncoderPostEncodeAction::ParkCpu);
4427 }
4428
4429 #[test]
4430 fn qwen_hot_text_encoder_never_parks_quantized() {
4431 let action = QwenImageEngine::qwen2_text_encoder_post_encode_action(
4432 Qwen2TextEncoderResidencyInput {
4433 on_gpu: true,
4434 is_quantized: true,
4435 is_metal: false,
4436 keep_te_ram: true,
4437 prompt_cache_miss: true,
4438 transformer_resident: true,
4439 free_vram_bytes: 7_999_999_999,
4440 required_vram_bytes: 8_000_000_000,
4441 },
4442 );
4443
4444 assert_eq!(action, Qwen2TextEncoderPostEncodeAction::Drop);
4445 }
4446
4447 #[test]
4448 fn qwen_hot_text_encoder_drops_when_transformer_not_resident() {
4449 let action = QwenImageEngine::qwen2_text_encoder_post_encode_action(
4450 Qwen2TextEncoderResidencyInput {
4451 on_gpu: true,
4452 is_quantized: true,
4453 is_metal: false,
4454 keep_te_ram: false,
4455 prompt_cache_miss: true,
4456 transformer_resident: false,
4457 free_vram_bytes: 10_000_000_000,
4458 required_vram_bytes: 8_000_000_000,
4459 },
4460 );
4461
4462 assert_eq!(action, Qwen2TextEncoderPostEncodeAction::Drop);
4463 }
4464
4465 #[test]
4466 fn qwen_transformer_hot_vae_eligibility_requires_quantized_cuda_components() {
4467 assert!(QwenImageEngine::qwen_transformer_can_stay_hot_for_vae(
4468 true, true, true
4469 ));
4470 assert!(!QwenImageEngine::qwen_transformer_can_stay_hot_for_vae(
4471 false, true, true
4472 ));
4473 assert!(!QwenImageEngine::qwen_transformer_can_stay_hot_for_vae(
4474 true, false, true
4475 ));
4476 assert!(!QwenImageEngine::qwen_transformer_can_stay_hot_for_vae(
4477 true, true, false
4478 ));
4479 }
4480
4481 #[test]
4482 fn qwen_transformer_paths_prefer_shards_when_present() {
4483 let dir = temp_test_dir("mold-qwen-shards");
4484 let shard_a = touch(&dir, "transformer-00001-of-00002.safetensors");
4485 let shard_b = touch(&dir, "transformer-00002-of-00002.safetensors");
4486 let engine = QwenImageEngine::new(
4487 "qwen-image:q4".to_string(),
4488 qwen_image_model_paths(
4489 dir.join("transformer.safetensors"),
4490 vec![shard_a.clone(), shard_b.clone()],
4491 dir.join("vae.safetensors"),
4492 Some(dir.join("tokenizer.json")),
4493 ),
4494 LoadStrategy::Sequential,
4495 0,
4496 false,
4497 None,
4498 );
4499
4500 assert_eq!(engine.transformer_paths(), vec![shard_a, shard_b]);
4501
4502 fs::remove_dir_all(dir).ok();
4503 }
4504
4505 #[test]
4506 fn qwen_validate_paths_accepts_existing_files() {
4507 let dir = temp_test_dir("mold-qwen-validate-ok");
4508 let shard_a = touch(&dir, "transformer-00001-of-00002.safetensors");
4509 let shard_b = touch(&dir, "transformer-00002-of-00002.safetensors");
4510 let vae = touch(&dir, "vae.safetensors");
4511 let tokenizer = touch(&dir, "tokenizer.json");
4512 let gguf = touch(&dir, "transformer.gguf");
4513
4514 let sharded = QwenImageEngine::new(
4515 "qwen-image:bf16".to_string(),
4516 qwen_image_model_paths(
4517 dir.join("transformer.safetensors"),
4518 vec![shard_a, shard_b],
4519 vae.clone(),
4520 Some(tokenizer.clone()),
4521 ),
4522 LoadStrategy::Sequential,
4523 0,
4524 false,
4525 None,
4526 );
4527 assert_eq!(sharded.validate_paths().unwrap(), tokenizer);
4528 assert!(!sharded.detect_is_quantized());
4529
4530 let quantized = QwenImageEngine::new(
4531 "qwen-image:q4".to_string(),
4532 qwen_image_model_paths(gguf, vec![], vae, Some(dir.join("tokenizer.json"))),
4533 LoadStrategy::Sequential,
4534 0,
4535 false,
4536 None,
4537 );
4538 assert!(quantized.detect_is_quantized());
4539
4540 fs::remove_dir_all(dir).ok();
4541 }
4542
4543 #[test]
4544 fn qwen_validate_paths_requires_text_tokenizer() {
4545 let dir = temp_test_dir("mold-qwen-validate-missing");
4546 let engine = QwenImageEngine::new(
4547 "qwen-image:q4".to_string(),
4548 qwen_image_model_paths(
4549 dir.join("transformer.gguf"),
4550 vec![],
4551 dir.join("vae.safetensors"),
4552 None,
4553 ),
4554 LoadStrategy::Sequential,
4555 0,
4556 false,
4557 None,
4558 );
4559
4560 let err = engine.validate_paths().unwrap_err();
4561 assert!(err.to_string().contains("text tokenizer path required"));
4562
4563 fs::remove_dir_all(dir).ok();
4564 }
4565
4566 #[test]
4567 fn qwen_image_loads_text_tokenizer_through_shared_pool() {
4568 let dir = temp_test_dir("mold-qwen-tokenizer-pool");
4569 let tokenizer_path = dir.join("tokenizer.json");
4570 tokenizers::Tokenizer::new(BPE::default())
4571 .save(&tokenizer_path, false)
4572 .unwrap();
4573
4574 let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
4575 let pooled = shared_pool
4576 .lock()
4577 .unwrap()
4578 .load_tokenizer(&tokenizer_path)
4579 .unwrap();
4580
4581 let engine = QwenImageEngine::new(
4582 "qwen-image:q4".to_string(),
4583 qwen_image_model_paths(
4584 dir.join("transformer.gguf"),
4585 vec![],
4586 dir.join("vae.safetensors"),
4587 Some(tokenizer_path.clone()),
4588 ),
4589 LoadStrategy::Sequential,
4590 0,
4591 false,
4592 Some(shared_pool),
4593 );
4594
4595 let loaded = engine.load_text_tokenizer(&tokenizer_path).unwrap();
4596
4597 assert!(Arc::ptr_eq(&pooled, &loaded));
4598 fs::remove_dir_all(dir).ok();
4599 }
4600
4601 #[test]
4602 fn qwen_image_loads_vae_tensors_through_shared_pool() {
4603 let dir = temp_test_dir("mold-qwen-vae-pool");
4604 let vae_path = dir.join("vae.safetensors");
4605 let weight = 1.0f32.to_le_bytes();
4606 let mut tensors = HashMap::new();
4607 tensors.insert(
4608 "encoder.conv_in.weight".to_string(),
4609 TensorView::new(SafeDtype::F32, vec![1], &weight).unwrap(),
4610 );
4611 serialize_to_file(&tensors, &None, &vae_path).unwrap();
4612
4613 let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
4614 let pooled = shared_pool
4615 .lock()
4616 .unwrap()
4617 .load_safetensors_cpu_tensors(std::slice::from_ref(&vae_path))
4618 .unwrap()
4619 .unwrap();
4620
4621 let engine = QwenImageEngine::new(
4622 "qwen-image:q4".to_string(),
4623 qwen_image_model_paths(
4624 dir.join("transformer.gguf"),
4625 vec![],
4626 vae_path.clone(),
4627 Some(dir.join("tokenizer.json")),
4628 ),
4629 LoadStrategy::Sequential,
4630 0,
4631 false,
4632 Some(shared_pool),
4633 );
4634
4635 let loaded = engine.load_vae_cpu_tensors().unwrap().unwrap();
4636
4637 assert!(Arc::ptr_eq(&pooled, &loaded));
4638 fs::remove_dir_all(dir).ok();
4639 }
4640
4641 #[test]
4642 fn qwen_img2img_uses_minus_one_to_one_source_normalization() {
4643 assert_eq!(
4644 QwenImageEngine::img2img_source_normalize_range(),
4645 img_utils::NormalizeRange::MinusOneToOne
4646 );
4647 }
4648}