1use anyhow::{bail, Result};
18use candle_core::{DType, Device, IndexOp, Tensor, D};
19use candle_transformers::models::z_image::postprocess_image;
20use candle_transformers::quantized_var_builder;
21use mold_core::{fit_to_target_area, GenerateRequest, GenerateResponse, ImageData, ModelPaths};
22use std::collections::HashMap;
23use std::path::Path;
24use std::sync::{Arc, Mutex};
25use std::time::Instant;
26use tokenizers::Tokenizer;
27
28use super::quantized_transformer::QuantizedQwenImageTransformer2DModel;
29use super::sampling::{image_seq_len, QwenImageScheduler};
30use super::transformer::{QwenImageConfig, QwenImageTransformer2DModel};
31use super::vae::QwenImageVae;
32use crate::cache::{
33 clear_cache, prompt_text_key, CachedTensor, LruCache, DEFAULT_PROMPT_CACHE_CAPACITY,
34};
35use crate::device::{
36 effective_device_ref, fits_in_memory, fmt_gb, free_vram_bytes, memory_status_string,
37 preflight_memory_check, qwen2_vram_threshold, should_use_gpu, usable_free_vram_bytes,
38};
39use crate::encoders;
40use crate::engine::{rand_seed, InferenceEngine, LoadStrategy};
41use crate::engine_base::EngineBase;
42use crate::image::{build_output_metadata, encode_image};
43use crate::img_utils;
44use crate::progress::{ProgressCallback, ProgressEvent, ProgressReporter};
45use crate::upscaler::tiling::{upscale_with_tiling, TilingConfig};
46
/// Fixed VRAM allowance (bytes) added to the per-resolution VAE workspace estimate when deciding
/// whether to tile the VAE decode and when sizing hot text-encoder residency.
const VAE_DECODE_VRAM_THRESHOLD: u64 = 2_500_000_000;
/// NOTE(review): presumably the stand-in negative prompt when the request supplies none —
/// confirm at the use site (not visible in this part of the file).
const QWEN_EMPTY_NEGATIVE_PROMPT: &str = " ";
/// Reference width; together with `QWEN_NATIVE_HEIGHT` it defines the "native" pixel count that
/// memory estimates are scaled against.
const QWEN_NATIVE_WIDTH: usize = 1328;
/// Reference height; see `QWEN_NATIVE_WIDTH`.
const QWEN_NATIVE_HEIGHT: usize = 1328;
/// CFG headroom (bytes) assumed at the native resolution for the quantized CUDA path; scaled
/// linearly by pixel count in `quantized_cuda_cfg_headroom`.
const QWEN_GGUF_NATIVE_CFG_HEADROOM: u64 = 14_000_000_000;
/// Lower bound (bytes) applied to the scaled CFG headroom.
const QWEN_GGUF_MIN_CFG_HEADROOM: u64 = 3_000_000_000;
/// Tile sizes tried in order by `decode_vae_tiled`; a later entry is tried only after the
/// previous one ran out of memory on CUDA.
const QWEN_VAE_TILE_SIZES: [u32; 3] = [64, 32, 16];
/// NOTE(review): looks like the target pixel area for edit-model source images — confirm at the
/// use site (not visible in this part of the file).
const QWEN_IMAGE_EDIT_VAE_AREA: u32 = 1024 * 1024;
/// System message placed at the start of the chat-formatted prompt built by
/// `qwen_image_edit_prompt`.
const QWEN_IMAGE_EDIT_SYSTEM_PROMPT: &str = "Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.";

/// Threshold (bytes) handed to `should_use_gpu` when deciding whether the BF16 text encoder can
/// run on the GPU.
const QWEN2_FP16_VRAM_THRESHOLD: u64 = 16_000_000_000;
/// Extra slack (bytes) added on top of the other terms in
/// `qwen2_hot_text_encoder_required_vram`.
const QWEN2_HOT_TE_RESIDENCY_HEADROOM: u64 = 1_000_000_000;
68
/// Text-encoder placement mode selected through the `MOLD_QWEN2_TEXT_ENCODER_MODE`
/// environment variable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Qwen2TextEncoderMode {
    Auto,
    Gpu,
    CpuStage,
    Cpu,
}

impl Qwen2TextEncoderMode {
    /// Parses `MOLD_QWEN2_TEXT_ENCODER_MODE` case-insensitively.
    ///
    /// Recognized values are `gpu`, `cpu-stage` / `cpu_stage`, and `cpu`. An unset, non-UTF-8,
    /// or unrecognized value yields [`Self::Auto`].
    fn from_env() -> Self {
        let requested = std::env::var("MOLD_QWEN2_TEXT_ENCODER_MODE")
            .map(|raw| raw.to_ascii_lowercase())
            .unwrap_or_default();
        match requested.as_str() {
            "gpu" => Self::Gpu,
            "cpu-stage" | "cpu_stage" => Self::CpuStage,
            "cpu" => Self::Cpu,
            _ => Self::Auto,
        }
    }
}
92
/// Pair of placement flags for the Qwen2.5 text encoder.
///
/// NOTE(review): the code that builds and consumes this plan is not visible in this part of
/// the file; the field notes below are inferred from names only.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Qwen2TextEncoderPlan {
    // Presumably: run the text encoder on the GPU — confirm.
    use_gpu: bool,
    // Presumably matches the `use_cpu_staging` flag of `maybe_spill_conditioning` — confirm.
    use_cpu_staging: bool,
}
98
/// Outcome of text-encoder variant selection (see `choose_text_encoder_source`).
#[derive(Debug, Clone)]
struct ResolvedQwen2TextEncoder {
    // Left empty by `choose_text_encoder_source`; presumably filled in later — TODO confirm.
    paths: Vec<std::path::PathBuf>,
    // Left empty by `choose_text_encoder_source`; presumably filled in later — TODO confirm.
    vision_paths: Vec<std::path::PathBuf>,
    // `true` for the quantized variants looked up in the manifest, `false` for "bf16".
    is_gguf: bool,
    // Manifest tag of the chosen variant, or the literal "bf16".
    variant_label: String,
    // Size of the chosen variant: the manifest size, or the caller-supplied BF16 size.
    size_bytes: u64,
    // Whether automatic placement would put this variant on the GPU.
    auto_use_gpu: bool,
}
108
/// How the text encoder is going to be used.
///
/// NOTE(review): the only visible consumer, `choose_text_encoder_source`, takes this as an
/// unused `_usage` parameter, so the variant semantics cannot be confirmed from here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Qwen2TextEncoderUsage {
    Sequential,
    Resident,
}
114
/// Decision returned by `qwen2_text_encoder_post_encode_action` about what to do with the text
/// encoder once prompt encoding has finished.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Qwen2TextEncoderPostEncodeAction {
    // Presumably: leave the encoder on the GPU — confirm at the call site.
    KeepGpu,
    // Presumably: move the encoder to CPU memory for later reuse — confirm at the call site.
    ParkCpu,
    // Presumably: release the encoder entirely — confirm at the call site.
    Drop,
}
121
/// Inputs read by `qwen2_text_encoder_post_encode_action`.
#[derive(Debug, Clone, Copy)]
struct Qwen2TextEncoderResidencyInput {
    // When false the decision is always `Drop`.
    on_gpu: bool,
    // When true, `ParkCpu` is never chosen.
    is_quantized: bool,
    // When true, neither `KeepGpu` nor `ParkCpu` is chosen.
    is_metal: bool,
    // Required (with the two flags above false) for `ParkCpu`.
    keep_te_ram: bool,
    // Required for `KeepGpu`.
    prompt_cache_miss: bool,
    // Required for `KeepGpu`.
    transformer_resident: bool,
    // Compared against `required_vram_bytes`; `KeepGpu` needs free >= required.
    free_vram_bytes: u64,
    required_vram_bytes: u64,
}
133
/// Summary statistics for a tensor, as produced by `tensor_stats`.
#[derive(Debug, Clone, Copy)]
struct QwenTensorStats {
    // Minimum over finite elements; NaN when there are no finite elements.
    min: f32,
    // Maximum over finite elements; NaN when there are no finite elements.
    max: f32,
    // Mean over finite elements; NaN when there are no finite elements.
    mean: f32,
    // Number of NaN elements.
    nan_count: u64,
    // Number of +infinity elements.
    pos_inf_count: u64,
    // Number of -infinity elements.
    neg_inf_count: u64,
    // Total element count, finite or not.
    total: usize,
}
144
145fn safetensors_is_fp8(path: &Path) -> bool {
148 if path.to_str().map(|s| s.contains("fp8")).unwrap_or(false) {
150 return true;
151 }
152 let Ok(tensors) = (unsafe { candle_core::safetensors::MmapedSafetensors::multi(&[path]) })
154 else {
155 return false;
156 };
157 for key in ["x_embedder.weight", "img_in.weight"] {
158 if let Ok(t) = tensors.load(key, &Device::Cpu) {
159 return t.dtype() == DType::F8E4M3;
160 }
161 }
162 false
163}
164
165fn text_encoder_is_fp8(paths: &[std::path::PathBuf]) -> bool {
169 if paths
171 .iter()
172 .any(|p| p.to_str().map(|s| s.contains("fp8")).unwrap_or(false))
173 {
174 return true;
175 }
176 let Some(first) = paths.first() else {
178 return false;
179 };
180 let Ok(tensors) = (unsafe { candle_core::safetensors::MmapedSafetensors::multi(&[first]) })
181 else {
182 return false;
183 };
184 for key in [
185 "model.embed_tokens.weight",
186 "model.layers.0.self_attn.q_proj.weight",
187 ] {
188 if let Ok(t) = tensors.load(key, &Device::Cpu) {
189 return t.dtype() == DType::F8E4M3;
190 }
191 }
192 false
193}
194
/// Bundle of loaded model components held by the engine (`EngineBase<LoadedQwenImage>`).
///
/// NOTE(review): the code that populates and consumes these fields is not visible in this part
/// of the file; the field notes below are inferred from names and types only.
struct LoadedQwenImage {
    // Optional, so presumably the transformer can be absent/unloaded at times — confirm.
    transformer: Option<QwenImageTransformer>,
    text_encoder: encoders::qwen2_text::Qwen2TextEncoder,
    vae: QwenImageVae,
    // Presumably kept so a CPU copy of the VAE can be reloaded for OOM fallbacks — confirm.
    vae_path: std::path::PathBuf,
    transformer_cfg: QwenImageConfig,
    device: Device,
    // Tracked separately from `device`, so the VAE may live on a different device.
    vae_device: Device,
    dtype: DType,
}
209
/// The transformer in one of its three loadable forms; `forward` / `forward_packed` dispatch to
/// whichever variant is held.
#[allow(clippy::large_enum_variant)]
enum QwenImageTransformer {
    BF16(QwenImageTransformer2DModel),
    // The only variant consulted by `supports_cfg_batching`; the others always report `true`.
    Quantized(QuantizedQwenImageTransformer2DModel),
    Offloaded(super::offload::OffloadedQwenImageTransformer),
}
216
217#[derive(Clone)]
218struct CachedPromptConditioning {
219 hidden_states: CachedTensor,
220 valid_len: usize,
221}
222
223impl CachedPromptConditioning {
224 fn from_parts(hidden_states: &Tensor, valid_len: usize) -> Result<Self> {
225 Ok(Self {
226 hidden_states: CachedTensor::from_tensor(hidden_states)?,
227 valid_len,
228 })
229 }
230
231 fn restore(&self, device: &Device, dtype: DType) -> Result<(Tensor, Tensor)> {
232 let hidden_states = self.hidden_states.restore(device, dtype)?;
233 let mut mask = vec![0u8; hidden_states.dim(1)?];
234 for value in &mut mask[..self.valid_len] {
235 *value = 1;
236 }
237 let attention_mask = Tensor::from_vec(mask, (1, hidden_states.dim(1)?), device)?;
238 Ok((hidden_states, attention_mask))
239 }
240}
241
242fn pad_text_conditioning(
243 hidden_states: &Tensor,
244 attention_mask: &Tensor,
245 target_len: usize,
246) -> Result<(Tensor, Tensor)> {
247 let seq_len = hidden_states.dim(1)?;
248 if seq_len == target_len {
249 return Ok((hidden_states.clone(), attention_mask.clone()));
250 }
251 if seq_len > target_len {
252 bail!("cannot shrink text conditioning from {seq_len} to {target_len}");
253 }
254
255 let hidden_dim = hidden_states.dim(2)?;
256 let pad_len = target_len - seq_len;
257 let pad_hs = Tensor::zeros(
258 (hidden_states.dim(0)?, pad_len, hidden_dim),
259 hidden_states.dtype(),
260 hidden_states.device(),
261 )?;
262 let pad_mask = Tensor::zeros(
263 (attention_mask.dim(0)?, pad_len),
264 attention_mask.dtype(),
265 attention_mask.device(),
266 )?;
267
268 Ok((
269 Tensor::cat(&[hidden_states, &pad_hs], 1)?,
270 Tensor::cat(&[attention_mask, &pad_mask], 1)?,
271 ))
272}
273
274fn align_cfg_conditioning(
275 cond_hs: &Tensor,
276 cond_mask: &Tensor,
277 uncond_hs: &Tensor,
278 uncond_mask: &Tensor,
279) -> Result<((Tensor, Tensor), (Tensor, Tensor))> {
280 let target_len = cond_hs.dim(1)?.max(uncond_hs.dim(1)?);
281 let cond = pad_text_conditioning(cond_hs, cond_mask, target_len)?;
282 let uncond = pad_text_conditioning(uncond_hs, uncond_mask, target_len)?;
283 Ok((cond, uncond))
284}
285
286impl QwenImageTransformer {
287 fn supports_cfg_batching(&self) -> bool {
288 match self {
289 Self::Quantized(model) => model.supports_cfg_batching(),
290 _ => true,
291 }
292 }
293
294 fn forward(
295 &self,
296 latents: &Tensor,
297 t: &Tensor,
298 encoder_hidden_states: &Tensor,
299 encoder_attention_mask: &Tensor,
300 ) -> Result<Tensor> {
301 match self {
302 Self::BF16(model) => {
303 Ok(model.forward(latents, t, encoder_hidden_states, encoder_attention_mask)?)
304 }
305 Self::Quantized(model) => {
306 Ok(model.forward(latents, t, encoder_hidden_states, encoder_attention_mask)?)
307 }
308 Self::Offloaded(model) => {
309 model.forward(latents, t, encoder_hidden_states, encoder_attention_mask)
310 }
311 }
312 }
313
314 fn forward_packed(
315 &self,
316 packed_latents: &Tensor,
317 t: &Tensor,
318 encoder_hidden_states: &Tensor,
319 encoder_attention_mask: &Tensor,
320 img_shapes: &[(usize, usize, usize)],
321 ) -> Result<Tensor> {
322 match self {
323 Self::BF16(model) => Ok(model.forward_packed(
324 packed_latents,
325 t,
326 encoder_hidden_states,
327 encoder_attention_mask,
328 img_shapes,
329 )?),
330 Self::Quantized(model) => Ok(model.forward_packed(
331 packed_latents,
332 t,
333 encoder_hidden_states,
334 encoder_attention_mask,
335 img_shapes,
336 )?),
337 Self::Offloaded(model) => model.forward_packed(
338 packed_latents,
339 t,
340 encoder_hidden_states,
341 encoder_attention_mask,
342 img_shapes,
343 ),
344 }
345 }
346}
347
/// Inference engine for Qwen-Image models.
pub struct QwenImageEngine {
    // Shared engine state (model name, paths, progress, GPU ordinal) wrapping the loaded model.
    base: EngineBase<LoadedQwenImage>,
    // LRU cache of prompt conditioning; `encode_prompt_cached` keys it via `prompt_text_key`.
    prompt_cache: Mutex<LruCache<String, CachedPromptConditioning>>,
    // When set, `load_transformer` takes the offload / low-memory paths.
    offload: bool,
    // Starts as `None` in `new`; presumably set by a later placement request — confirm.
    pending_placement: Option<mold_core::types::DevicePlacement>,
    // LoRA stack that `load_transformer` applies when it is non-empty.
    pending_loras: Vec<mold_core::LoraWeight>,
    // Starts empty in `new`; presumably records the LoRA stack currently applied — confirm.
    #[allow(dead_code)]
    active_lora_fingerprint: Vec<QwenImageLoraFingerprint>,
    // When present, `load_text_tokenizer` loads the tokenizer through this pool.
    shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
}
371
372#[derive(Clone, PartialEq, Eq, Debug)]
374#[allow(dead_code)]
375struct QwenImageLoraFingerprint {
376 path_hash: u64,
377 scale_bits: u64,
378}
379
380impl QwenImageLoraFingerprint {
381 #[allow(dead_code)]
382 fn from_lora(lora: &mold_core::LoraWeight) -> Self {
383 Self {
384 path_hash: super::lora::lora_path_hash(&lora.path),
385 scale_bits: lora.scale.to_bits(),
386 }
387 }
388}
389
390#[allow(dead_code)]
391fn fingerprint_stack(loras: &[mold_core::LoraWeight]) -> Vec<QwenImageLoraFingerprint> {
392 loras
393 .iter()
394 .map(QwenImageLoraFingerprint::from_lora)
395 .collect()
396}
397
398fn effective_loras(req: &mold_core::GenerateRequest) -> Vec<mold_core::LoraWeight> {
402 const ZERO_SCALE_EPS: f64 = 1e-8;
405
406 let raw: Vec<mold_core::LoraWeight> = if let Some(plural) = &req.loras {
407 if !plural.is_empty() {
408 plural.clone()
409 } else {
410 req.lora.iter().cloned().collect()
411 }
412 } else {
413 req.lora.iter().cloned().collect()
414 };
415
416 raw.into_iter()
417 .filter(|w| {
418 let keep = w.scale.abs() > ZERO_SCALE_EPS;
419 if !keep {
420 tracing::debug!(
421 path = w.path.as_str(),
422 scale = w.scale,
423 "dropping zero-scale LoRA from effective Qwen-Image stack"
424 );
425 }
426 keep
427 })
428 .collect()
429}
430
431impl QwenImageEngine {
432 fn is_edit_family(&self) -> bool {
433 self.base.model_name.starts_with("qwen-image-edit")
434 }
435
436 fn should_preload_text_encoder(&self) -> bool {
437 !self.is_edit_family()
438 }
439
440 fn text_encoder_load_dtype(use_gpu: bool, gpu_dtype: DType) -> DType {
441 if use_gpu {
442 gpu_dtype
443 } else {
444 DType::F32
448 }
449 }
450
451 fn transformer_config(&self) -> QwenImageConfig {
452 if self.is_edit_family() {
453 QwenImageConfig::qwen_image_edit_2511()
454 } else {
455 QwenImageConfig::qwen_image_2512()
456 }
457 }
458
459 fn qwen_image_edit_prompt(prompt: &str, image_count: usize) -> String {
460 let picture_prefix = (0..image_count)
461 .map(|idx| {
462 format!(
463 "Picture {}: <|vision_start|><|image_pad|><|vision_end|>",
464 idx + 1
465 )
466 })
467 .collect::<String>();
468 format!(
469 "<|im_start|>system\n{QWEN_IMAGE_EDIT_SYSTEM_PROMPT}<|im_end|>\n<|im_start|>user\n{picture_prefix}{prompt}<|im_end|>\n<|im_start|>assistant\n"
470 )
471 }
472
473 fn qwen_image_edit_image_dims(image: &[u8], target_area: u32) -> Result<(u32, u32)> {
474 let img = image::load_from_memory(image)?;
475 Ok(fit_to_target_area(
476 img.width().max(1),
477 img.height().max(1),
478 target_area,
479 16,
480 ))
481 }
482
483 fn pack_latents_4d(latents: &Tensor) -> Result<Tensor> {
484 let (batch, channels, height, width) = latents.dims4()?;
485 let height_blocks = height / 2;
486 let width_blocks = width / 2;
487 latents
488 .reshape((batch, channels, height_blocks, 2, width_blocks, 2))?
489 .permute((0, 2, 4, 1, 3, 5))?
490 .reshape((batch, height_blocks * width_blocks, channels * 4))
491 .map_err(Into::into)
492 }
493
494 fn unpack_latents_packed(latents: &Tensor, latent_h: usize, latent_w: usize) -> Result<Tensor> {
495 let batch = latents.dim(0)?;
496 latents
497 .reshape((batch, latent_h / 2, latent_w / 2, 16, 2, 2))?
498 .permute((0, 3, 1, 4, 2, 5))?
499 .reshape((batch, 16, latent_h, latent_w))
500 .map_err(Into::into)
501 }
502
/// Normalization range used when decoding img2img source images (see
/// `encode_vae_with_fallback`): always `MinusOneToOne`.
fn img2img_source_normalize_range() -> img_utils::NormalizeRange {
    img_utils::NormalizeRange::MinusOneToOne
}
506
/// Heuristically classifies an error as out-of-memory by looking for known marker substrings
/// (case-sensitive) in its rendered message.
fn is_oom_error(err: &impl std::fmt::Display) -> bool {
    const OOM_MARKERS: [&str; 3] = [
        "OUT_OF_MEMORY",
        "out of memory",
        "cudaErrorMemoryAllocation",
    ];
    let rendered = err.to_string();
    OOM_MARKERS.iter().any(|marker| rendered.contains(*marker))
}
515
516 fn with_cuda_oom_cpu_fallback<T, FPrimary, FFallback, FOom>(
517 primary: FPrimary,
518 fallback: FFallback,
519 is_cuda: bool,
520 sync_device: &Device,
521 progress: &ProgressReporter,
522 oom_message: &str,
523 is_oom: FOom,
524 ) -> Result<T>
525 where
526 FPrimary: FnOnce() -> Result<T>,
527 FFallback: FnOnce() -> Result<T>,
528 FOom: Fn(&anyhow::Error) -> bool,
529 {
530 match primary() {
531 Ok(value) => Ok(value),
532 Err(err) if is_cuda && is_oom(&err) => {
533 progress.info(oom_message);
534 sync_device.synchronize()?;
535 fallback()
536 }
537 Err(err) => Err(err),
538 }
539 }
540
541 #[allow(clippy::too_many_arguments)]
542 fn with_cuda_tiled_then_cpu_fallback<T, FPrimary, FTiled, FCpu, FOom>(
543 primary: FPrimary,
544 tiled: FTiled,
545 cpu_fallback: FCpu,
546 is_cuda: bool,
547 prefer_tiled: bool,
548 sync_device: &Device,
549 progress: &ProgressReporter,
550 tiled_message: &str,
551 cpu_message: &str,
552 is_oom: FOom,
553 ) -> Result<T>
554 where
555 FPrimary: FnOnce() -> Result<T>,
556 FTiled: FnOnce() -> Result<T>,
557 FCpu: FnOnce() -> Result<T>,
558 FOom: Fn(&anyhow::Error) -> bool,
559 {
560 if is_cuda && prefer_tiled {
561 progress.info("Selecting tiled GPU VAE decode proactively");
562 match tiled() {
563 Ok(value) => return Ok(value),
564 Err(tile_err) if is_oom(&tile_err) => {
565 progress.info(cpu_message);
566 sync_device.synchronize()?;
567 return cpu_fallback();
568 }
569 Err(tile_err) => return Err(tile_err),
570 }
571 }
572
573 match primary() {
574 Ok(value) => Ok(value),
575 Err(err) if is_cuda && is_oom(&err) => {
576 progress.info(tiled_message);
577 sync_device.synchronize()?;
578 match tiled() {
579 Ok(value) => Ok(value),
580 Err(tile_err) if is_oom(&tile_err) => {
581 progress.info(cpu_message);
582 sync_device.synchronize()?;
583 cpu_fallback()
584 }
585 Err(tile_err) => Err(tile_err),
586 }
587 }
588 Err(err) => Err(err),
589 }
590 }
591
/// Estimated VAE decode workspace for a `width` x `height` image: pixel count times 4096,
/// saturating at `u64::MAX`.
fn qwen_vae_decode_workspace_bytes(width: u32, height: u32) -> u64 {
    // 4 * 1024 per pixel, applied as a single saturating multiply.
    const WORKSPACE_PER_PIXEL: u64 = 4 * 1024;
    let pixel_count = u64::from(width) * u64::from(height);
    pixel_count.saturating_mul(WORKSPACE_PER_PIXEL)
}
600
601 fn should_proactively_tile_vae_decode(
602 width: u32,
603 height: u32,
604 vae_is_cuda: bool,
605 free_vram_bytes: u64,
606 ) -> bool {
607 if !vae_is_cuda || free_vram_bytes == 0 {
608 return false;
609 }
610 let native_pixels = (QWEN_NATIVE_WIDTH * QWEN_NATIVE_HEIGHT) as u64;
611 let pixels = width as u64 * height as u64;
612 if pixels < native_pixels.saturating_mul(3) / 4 {
613 return false;
614 }
615 let required = VAE_DECODE_VRAM_THRESHOLD
616 .saturating_add(Self::qwen_vae_decode_workspace_bytes(width, height));
617 free_vram_bytes < required
618 }
619
620 fn qwen2_text_encoder_post_encode_action(
621 input: Qwen2TextEncoderResidencyInput,
622 ) -> Qwen2TextEncoderPostEncodeAction {
623 if !input.on_gpu {
624 return Qwen2TextEncoderPostEncodeAction::Drop;
625 }
626 if input.prompt_cache_miss
627 && input.transformer_resident
628 && !input.is_metal
629 && input.free_vram_bytes >= input.required_vram_bytes
630 {
631 return Qwen2TextEncoderPostEncodeAction::KeepGpu;
632 }
633 if input.keep_te_ram && !input.is_metal && !input.is_quantized {
634 return Qwen2TextEncoderPostEncodeAction::ParkCpu;
635 }
636 Qwen2TextEncoderPostEncodeAction::Drop
637 }
638
639 fn qwen2_hot_text_encoder_required_vram(
640 width: u32,
641 height: u32,
642 cfg_batch: u32,
643 dtype: DType,
644 ) -> u64 {
645 crate::device::activation_bytes(
646 width,
647 height,
648 cfg_batch,
649 crate::device::dtype_bytes(dtype),
650 crate::device::ActivationFamily::QwenImageDit,
651 )
652 .saturating_add(VAE_DECODE_VRAM_THRESHOLD)
653 .saturating_add(Self::qwen_vae_decode_workspace_bytes(width, height))
654 .saturating_add(QWEN2_HOT_TE_RESIDENCY_HEADROOM)
655 }
656
657 fn decode_vae_tiled(
658 latents: &Tensor,
659 vae: &QwenImageVae,
660 vae_device: &Device,
661 progress: &ProgressReporter,
662 ) -> Result<Tensor> {
663 for tile_size in QWEN_VAE_TILE_SIZES {
664 let overlap = (tile_size / 4).max(4);
665 progress.info(&format!(
666 "Retrying VAE decode with tiled GPU decode (tile {} overlap {})",
667 tile_size, overlap
668 ));
669 let config = TilingConfig {
670 tile_size,
671 overlap,
672 min_tile_size: 16,
673 };
674 let forward = |tile: &Tensor| {
675 let tile = tile.to_device(vae_device)?.to_dtype(DType::F32)?;
676 vae.decode(&tile).map_err(Into::into)
677 };
678 match upscale_with_tiling(latents, &forward, 8, &config, &Device::Cpu, progress) {
683 Ok(image) => return Ok(image),
684 Err(e) if vae_device.is_cuda() && Self::is_oom_error(&e) => {
685 if let Err(sync_err) = vae_device.synchronize() {
686 tracing::warn!(
687 "failed to synchronize CUDA device after tiled VAE OOM: {sync_err}"
688 );
689 }
690 }
691 Err(e) => return Err(e),
692 }
693 }
694
695 bail!("tiled VAE decode still ran out of memory")
696 }
697
698 fn decode_vae_with_fallback<F>(
699 latents: &Tensor,
700 vae: &QwenImageVae,
701 vae_device: &Device,
702 sync_device: &Device,
703 progress: &ProgressReporter,
704 prefer_tiled: bool,
705 load_cpu_vae: F,
706 ) -> Result<Tensor>
707 where
708 F: FnOnce() -> Result<QwenImageVae>,
709 {
710 let decode_latents = latents.to_device(vae_device)?.to_dtype(DType::F32)?;
711 Self::debug_tensor_stats("latents_pre_vae", &decode_latents);
712 Self::with_cuda_tiled_then_cpu_fallback(
713 || vae.decode(&decode_latents).map_err(Into::into),
714 || Self::decode_vae_tiled(latents, vae, vae_device, progress),
715 || {
716 let cpu_vae = load_cpu_vae()?;
717 let cpu_latents = latents.to_device(&Device::Cpu)?.to_dtype(DType::F32)?;
718 cpu_vae.decode(&cpu_latents).map_err(Into::into)
719 },
720 vae_device.is_cuda(),
721 prefer_tiled,
722 sync_device,
723 progress,
724 "VAE decode OOM on GPU — retrying with tiled GPU decode",
725 "Tiled GPU VAE decode OOM — retrying on CPU",
726 Self::is_oom_error,
727 )
728 }
729
730 #[allow(clippy::too_many_arguments)]
732 fn encode_vae_with_fallback(
733 source_bytes: &[u8],
734 width: u32,
735 height: u32,
736 vae: &QwenImageVae,
737 vae_device: &Device,
738 sync_device: &Device,
739 progress: &ProgressReporter,
740 load_cpu_vae: impl FnOnce() -> Result<QwenImageVae>,
741 ) -> Result<Tensor> {
742 progress.stage_start("Encoding source image (VAE)");
743 let encode_start = Instant::now();
744
745 let source_tensor = img_utils::decode_source_image(
747 source_bytes,
748 width,
749 height,
750 Self::img2img_source_normalize_range(),
751 vae_device,
752 DType::F32,
753 )?;
754
755 let result = Self::with_cuda_oom_cpu_fallback(
756 || vae.encode(&source_tensor).map_err(Into::into),
757 || {
758 let cpu_vae = load_cpu_vae()?;
759 let cpu_source = img_utils::decode_source_image(
760 source_bytes,
761 width,
762 height,
763 Self::img2img_source_normalize_range(),
764 &Device::Cpu,
765 DType::F32,
766 )?;
767 cpu_vae.encode(&cpu_source).map_err(Into::into)
768 },
769 vae_device.is_cuda(),
770 sync_device,
771 progress,
772 "VAE encode OOM on GPU — retrying on CPU",
773 Self::is_oom_error,
774 );
775
776 progress.stage_done("Encoding source image (VAE)", encode_start.elapsed());
777 result
778 }
779
780 fn choose_text_encoder_source(
781 preference: Option<&str>,
782 is_cuda: bool,
783 is_metal: bool,
784 free_vram: u64,
785 bf16_size_bytes: u64,
786 _usage: Qwen2TextEncoderUsage,
787 ) -> Result<ResolvedQwen2TextEncoder> {
788 match preference {
789 Some(tag) if tag != "auto" && tag != "bf16" => {
790 let variant = mold_core::manifest::find_qwen2_vl_variant(tag).ok_or_else(|| {
791 anyhow::anyhow!(
792 "unknown Qwen2.5-VL variant '{}'. Valid: bf16, auto, q8, q6, q5, q4, q3, q2",
793 tag
794 )
795 })?;
796 Ok(ResolvedQwen2TextEncoder {
797 paths: vec![],
798 vision_paths: vec![],
799 is_gguf: true,
800 variant_label: variant.tag.to_string(),
801 size_bytes: variant.size_bytes,
802 auto_use_gpu: should_use_gpu(
803 is_cuda,
804 is_metal,
805 free_vram,
806 qwen2_vram_threshold(variant.size_bytes),
807 ),
808 })
809 }
810 Some("bf16") => Ok(ResolvedQwen2TextEncoder {
811 paths: vec![],
812 vision_paths: vec![],
813 is_gguf: false,
814 variant_label: "bf16".to_string(),
815 size_bytes: bf16_size_bytes,
816 auto_use_gpu: should_use_gpu(
817 is_cuda,
818 is_metal,
819 free_vram,
820 QWEN2_FP16_VRAM_THRESHOLD,
821 ),
822 }),
823 _ if is_metal => {
824 for tag in ["q6", "q4"] {
825 let variant = mold_core::manifest::find_qwen2_vl_variant(tag)
826 .expect("known Metal auto qwen2 variant missing");
827 if fits_in_memory(
828 is_cuda,
829 is_metal,
830 free_vram,
831 qwen2_vram_threshold(variant.size_bytes),
832 ) {
833 return Ok(ResolvedQwen2TextEncoder {
834 paths: vec![],
835 vision_paths: vec![],
836 is_gguf: true,
837 variant_label: variant.tag.to_string(),
838 size_bytes: variant.size_bytes,
839 auto_use_gpu: true,
840 });
841 }
842 }
843 let fallback = mold_core::manifest::find_qwen2_vl_variant("q4")
844 .expect("known Metal fallback qwen2 variant missing");
845 Ok(ResolvedQwen2TextEncoder {
846 paths: vec![],
847 vision_paths: vec![],
848 is_gguf: true,
849 variant_label: fallback.tag.to_string(),
850 size_bytes: fallback.size_bytes,
851 auto_use_gpu: true,
852 })
853 }
854 _ => {
855 let bf16_on_gpu =
856 should_use_gpu(is_cuda, is_metal, free_vram, QWEN2_FP16_VRAM_THRESHOLD);
857 if bf16_on_gpu {
858 return Ok(ResolvedQwen2TextEncoder {
859 paths: vec![],
860 vision_paths: vec![],
861 is_gguf: false,
862 variant_label: "bf16".to_string(),
863 size_bytes: bf16_size_bytes,
864 auto_use_gpu: true,
865 });
866 }
867
868 if is_cuda {
869 let fallback_tag = "q4";
870 let fallback = mold_core::manifest::find_qwen2_vl_variant(fallback_tag)
871 .expect("known CUDA fallback qwen2 variant missing");
872 return Ok(ResolvedQwen2TextEncoder {
873 paths: vec![],
874 vision_paths: vec![],
875 is_gguf: true,
876 variant_label: fallback.tag.to_string(),
877 size_bytes: fallback.size_bytes,
878 auto_use_gpu: fits_in_memory(
879 is_cuda,
880 is_metal,
881 free_vram,
882 qwen2_vram_threshold(fallback.size_bytes),
883 ),
884 });
885 }
886
887 Ok(ResolvedQwen2TextEncoder {
888 paths: vec![],
889 vision_paths: vec![],
890 is_gguf: false,
891 variant_label: "bf16".to_string(),
892 size_bytes: bf16_size_bytes,
893 auto_use_gpu: false,
894 })
895 }
896 }
897 }
898
899 fn tensor_stats(tensor: &Tensor) -> Result<QwenTensorStats> {
900 let t = tensor.to_dtype(DType::F32)?;
901 let values = t.flatten_all()?.to_vec1::<f32>()?;
902 let mut min = f32::INFINITY;
903 let mut max = f32::NEG_INFINITY;
904 let mut sum = 0.0f64;
905 let mut finite_count = 0usize;
906 let mut nan_count = 0u64;
907 let mut pos_inf_count = 0u64;
908 let mut neg_inf_count = 0u64;
909 for value in &values {
910 if value.is_nan() {
911 nan_count += 1;
912 } else if *value == f32::INFINITY {
913 pos_inf_count += 1;
914 } else if *value == f32::NEG_INFINITY {
915 neg_inf_count += 1;
916 } else {
917 min = min.min(*value);
918 max = max.max(*value);
919 sum += *value as f64;
920 finite_count += 1;
921 }
922 }
923 let mean = if finite_count == 0 {
924 f32::NAN
925 } else {
926 (sum / finite_count as f64) as f32
927 };
928 if finite_count == 0 {
929 min = f32::NAN;
930 max = f32::NAN;
931 }
932 Ok(QwenTensorStats {
933 min,
934 max,
935 mean,
936 nan_count,
937 pos_inf_count,
938 neg_inf_count,
939 total: values.len(),
940 })
941 }
942
943 fn format_tensor_stats(name: &str, stats: QwenTensorStats) -> String {
944 format!(
945 "[qwen-debug] {name}: min={:.4} max={:.4} mean={:.4} NaN={}/{} ({:.1}%) +Inf={} -Inf={}",
946 stats.min,
947 stats.max,
948 stats.mean,
949 stats.nan_count,
950 stats.total,
951 stats.nan_count as f64 / stats.total.max(1) as f64 * 100.0,
952 stats.pos_inf_count,
953 stats.neg_inf_count
954 )
955 }
956
957 fn near_black_image_stats(stats: QwenTensorStats) -> bool {
958 if stats.nan_count > 0
959 || stats.pos_inf_count > 0
960 || stats.neg_inf_count > 0
961 || !stats.min.is_finite()
962 || !stats.max.is_finite()
963 || !stats.mean.is_finite()
964 {
965 return false;
966 }
967 let scale = if stats.max <= 1.0 { 1.0 } else { 255.0 };
968 stats.max <= 0.02 * scale && stats.mean <= 0.01 * scale
969 }
970
971 fn validate_qwen_tensor_boundary(name: &str, tensor: &Tensor) -> Result<QwenTensorStats> {
972 let stats = Self::tensor_stats(tensor)?;
973 if stats.nan_count > 0
974 || stats.pos_inf_count > 0
975 || stats.neg_inf_count > 0
976 || !stats.min.is_finite()
977 || !stats.max.is_finite()
978 || !stats.mean.is_finite()
979 {
980 bail!(
981 "Qwen diagnostic boundary '{name}' contains non-finite values: {}",
982 Self::format_tensor_stats(name, stats)
983 );
984 }
985 Ok(stats)
986 }
987
988 fn debug_tensor_stats(name: &str, tensor: &Tensor) {
989 if std::env::var_os("MOLD_QWEN_DEBUG").is_none() {
990 return;
991 }
992 match Self::tensor_stats(tensor) {
993 Ok(stats) => eprintln!("{}", Self::format_tensor_stats(name, stats)),
994 Err(err) => eprintln!("[qwen-debug] {name}: <failed: {err}>"),
995 }
996 }
997
998 pub fn new(
999 model_name: String,
1000 paths: ModelPaths,
1001 load_strategy: LoadStrategy,
1002 gpu_ordinal: usize,
1003 offload: bool,
1004 shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
1005 ) -> Self {
1006 Self {
1007 base: EngineBase::new(model_name, paths, load_strategy, gpu_ordinal),
1008 prompt_cache: Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY)),
1009 offload,
1010 pending_placement: None,
1011 pending_loras: Vec::new(),
1012 active_lora_fingerprint: Vec::new(),
1013 shared_pool,
1014 }
1015 }
1016
1017 fn load_text_tokenizer(&self, tokenizer_path: &Path) -> Result<Arc<Tokenizer>> {
1018 if let Some(shared_pool) = &self.shared_pool {
1019 return shared_pool.lock().unwrap().load_tokenizer(tokenizer_path);
1020 }
1021 Tokenizer::from_file(tokenizer_path)
1022 .map(Arc::new)
1023 .map_err(|e| anyhow::anyhow!("failed to load Qwen2.5 tokenizer: {e}"))
1024 }
1025
1026 fn encode_prompt_cached(
1027 progress: &ProgressReporter,
1028 prompt_cache: &Mutex<LruCache<String, CachedPromptConditioning>>,
1029 text_encoder: &mut encoders::qwen2_text::Qwen2TextEncoder,
1030 prompt: &str,
1031 device: &Device,
1032 dtype: DType,
1033 ) -> Result<(Tensor, Tensor)> {
1034 let cache_key = prompt_text_key(prompt);
1035 if let Some(cached) = prompt_cache
1036 .lock()
1037 .expect("cache poisoned")
1038 .get_cloned(&cache_key)
1039 {
1040 progress.cache_hit("prompt conditioning");
1041 return cached.restore(device, dtype);
1042 }
1043
1044 progress.stage_start("Encoding prompt (Qwen2.5)");
1045 let encode_start = Instant::now();
1046 let (hidden_states, _attention_mask, valid_len) =
1047 text_encoder.encode(prompt, device, dtype)?;
1048 progress.stage_done("Encoding prompt (Qwen2.5)", encode_start.elapsed());
1049
1050 prompt_cache.lock().expect("cache poisoned").insert(
1051 cache_key,
1052 CachedPromptConditioning::from_parts(&hidden_states, valid_len)?,
1053 );
1054
1055 let mut mask = vec![0u8; hidden_states.dim(1)?];
1056 for value in &mut mask[..valid_len] {
1057 *value = 1;
1058 }
1059 let attention_mask = Tensor::from_vec(mask, (1, hidden_states.dim(1)?), device)?;
1060 Ok((hidden_states, attention_mask))
1061 }
1062
1063 fn spill_conditioning_to_cpu(
1064 hidden_states: Tensor,
1065 attention_mask: Tensor,
1066 ) -> Result<(Tensor, Tensor)> {
1067 Ok((
1068 hidden_states
1069 .to_device(&Device::Cpu)?
1070 .to_dtype(DType::F32)?,
1071 attention_mask.to_device(&Device::Cpu)?,
1072 ))
1073 }
1074
1075 fn maybe_spill_conditioning(
1076 use_cpu_staging: bool,
1077 hidden_states: Tensor,
1078 attention_mask: Tensor,
1079 ) -> Result<(Tensor, Tensor)> {
1080 if use_cpu_staging {
1081 Self::spill_conditioning_to_cpu(hidden_states, attention_mask)
1082 } else {
1083 Ok((hidden_states, attention_mask))
1084 }
1085 }
1086
1087 fn transformer_paths(&self) -> Vec<std::path::PathBuf> {
1089 if !self.base.paths.transformer_shards.is_empty() {
1090 self.base.paths.transformer_shards.clone()
1091 } else {
1092 vec![self.base.paths.transformer.clone()]
1093 }
1094 }
1095
1096 fn detect_is_quantized(&self) -> bool {
1097 self.base
1098 .paths
1099 .transformer
1100 .extension()
1101 .and_then(|e| e.to_str())
1102 .map(|e| e.eq_ignore_ascii_case("gguf"))
1103 .unwrap_or(false)
1104 }
1105
1106 fn validate_paths(&self) -> Result<std::path::PathBuf> {
1108 let text_tokenizer_path =
1109 self.base.paths.text_tokenizer.as_ref().ok_or_else(|| {
1110 anyhow::anyhow!("text tokenizer path required for Qwen-Image models")
1111 })?;
1112 if !text_tokenizer_path.exists() {
1113 bail!(
1114 "text tokenizer file not found: {}",
1115 text_tokenizer_path.display()
1116 );
1117 }
1118
1119 let xformer_paths = self.transformer_paths();
1120 for path in &xformer_paths {
1121 if !path.exists() {
1122 bail!("transformer file not found: {}", path.display());
1123 }
1124 }
1125 if !self.base.paths.vae.exists() {
1126 bail!("VAE file not found: {}", self.base.paths.vae.display());
1127 }
1128
1129 Ok(text_tokenizer_path.clone())
1130 }
1131
1132 fn quantized_cuda_cfg_headroom(width: usize, height: usize) -> u64 {
1133 let native_pixels = (QWEN_NATIVE_WIDTH * QWEN_NATIVE_HEIGHT) as f64;
1134 let pixels = (width.max(1) * height.max(1)) as f64;
1135 let scaled =
1136 (QWEN_GGUF_NATIVE_CFG_HEADROOM as f64 * (pixels / native_pixels)).round() as u64;
1137 scaled.max(QWEN_GGUF_MIN_CFG_HEADROOM)
1138 }
1139
1140 fn should_split_cfg_quantized_cuda(
1141 transformer_size: u64,
1142 free_vram: u64,
1143 width: usize,
1144 height: usize,
1145 ) -> bool {
1146 if free_vram == 0 {
1147 return true;
1150 }
1151 let estimated_peak =
1152 transformer_size.saturating_add(Self::quantized_cuda_cfg_headroom(width, height));
1153 estimated_peak > free_vram
1154 }
1155
1156 fn load_transformer(
1158 &self,
1159 device: &Device,
1160 dtype: DType,
1161 cfg: &QwenImageConfig,
1162 width: usize,
1163 height: usize,
1164 ) -> Result<QwenImageTransformer> {
1165 let active_loras = &self.pending_loras;
1166 let has_lora = !active_loras.is_empty();
1167 if self.detect_is_quantized() {
1168 let transformer_size = std::fs::metadata(&self.base.paths.transformer)
1169 .map(|m| m.len())
1170 .unwrap_or(0);
1171 let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
1173 let split_cfg_for_memory = device.is_cuda()
1174 && (self.offload
1175 || Self::should_split_cfg_quantized_cuda(
1176 transformer_size,
1177 free,
1178 width,
1179 height,
1180 ));
1181 if self.offload && device.is_cuda() {
1182 self.base.progress.info(
1183 "Quantized Qwen CUDA offload requested — using low-memory split-CFG mode until GGUF block offload lands",
1184 );
1185 } else if split_cfg_for_memory {
1186 let estimated_peak = transformer_size
1187 .saturating_add(Self::quantized_cuda_cfg_headroom(width, height));
1188 self.base.progress.info(&format!(
1189 "Using low-memory quantized Qwen CUDA path (est. peak {}, {} free at {}x{})",
1190 fmt_gb(estimated_peak),
1191 fmt_gb(free),
1192 width,
1193 height,
1194 ));
1195 }
1196 let vb = if has_lora {
1197 let adapters = super::lora::load_lora_adapters(active_loras, &self.base.progress)?;
1198 let specs: Vec<super::lora::QwenImageLoraSpec<'_>> = adapters
1199 .iter()
1200 .zip(active_loras.iter())
1201 .map(|(adapter, w)| super::lora::QwenImageLoraSpec {
1202 adapter: adapter.as_ref(),
1203 scale: w.scale,
1204 path_hash: super::lora::lora_path_hash(&w.path),
1205 })
1206 .collect();
1207 super::lora::gguf_lora_var_builder(
1208 &self.base.paths.transformer,
1209 &specs,
1210 device,
1211 &self.base.progress,
1212 None,
1213 )?
1214 } else {
1215 quantized_var_builder::VarBuilder::from_gguf(&self.base.paths.transformer, device)?
1216 };
1217 Ok(QwenImageTransformer::Quantized(
1218 QuantizedQwenImageTransformer2DModel::new(cfg, vb, device, !split_cfg_for_memory)?,
1219 ))
1220 } else {
1221 let xformer_paths = self.transformer_paths();
1222 let is_fp8 = xformer_paths
1223 .first()
1224 .map(|p| safetensors_is_fp8(p))
1225 .unwrap_or(false);
1226
1227 let mem_size: u64 = xformer_paths
1231 .iter()
1232 .filter_map(|p| std::fs::metadata(p).ok())
1233 .map(|m| m.len())
1234 .sum();
1235 let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
1237 let activation_budget = crate::device::activation_bytes(
1240 width as u32,
1241 height as u32,
1242 2,
1243 crate::device::dtype_bytes(dtype),
1244 crate::device::ActivationFamily::QwenImageDit,
1245 );
1246 let use_offload =
1247 self.offload || crate::device::should_offload(mem_size, free, activation_budget);
1248
1249 if is_fp8 {
1250 self.base
1251 .progress
1252 .info("Detected FP8 safetensors — loading with scale dequantization");
1253 }
1254
1255 if use_offload {
1256 if has_lora {
1257 bail!(
1258 "Qwen-Image LoRA support is not yet wired through the block-offload \
1259 transformer path. Disable offload (drop --offload / unset MOLD_OFFLOAD), \
1260 or pick a checkpoint that fits without offload, to use LoRAs."
1261 );
1262 }
1263 let (gpu_vb, cpu_vb) = if is_fp8 {
1265 let gpu = crate::weight_loader::load_fp8_safetensors(
1266 &xformer_paths,
1267 device,
1268 "Qwen-Image transformer (offload, GPU)",
1269 &self.base.progress,
1270 )?;
1271 let cpu = crate::weight_loader::load_fp8_safetensors(
1272 &xformer_paths,
1273 &Device::Cpu,
1274 "Qwen-Image transformer (offload, CPU)",
1275 &self.base.progress,
1276 )?;
1277 (gpu, cpu)
1278 } else {
1279 let gpu = crate::weight_loader::load_safetensors_with_progress(
1280 &xformer_paths,
1281 dtype,
1282 device,
1283 "Qwen-Image transformer (offload, GPU)",
1284 &self.base.progress,
1285 )?;
1286 let cpu = unsafe {
1287 candle_nn::VarBuilder::from_mmaped_safetensors(
1288 &xformer_paths
1289 .iter()
1290 .map(|p| p.as_path())
1291 .collect::<Vec<_>>(),
1292 DType::BF16,
1293 &Device::Cpu,
1294 )?
1295 };
1296 (gpu, cpu)
1297 };
1298 Ok(QwenImageTransformer::Offloaded(
1299 super::offload::OffloadedQwenImageTransformer::load(
1300 gpu_vb,
1301 cpu_vb,
1302 cfg,
1303 device,
1304 self.base.gpu_ordinal,
1305 activation_budget,
1306 &self.base.progress,
1307 )?,
1308 ))
1309 } else {
1310 let xformer_vb = if has_lora {
1311 self.build_bf16_lora_var_builder(
1312 &xformer_paths,
1313 dtype,
1314 device,
1315 is_fp8,
1316 active_loras,
1317 )?
1318 } else if is_fp8 {
1319 crate::weight_loader::load_fp8_safetensors(
1320 &xformer_paths,
1321 device,
1322 "Qwen-Image transformer",
1323 &self.base.progress,
1324 )?
1325 } else {
1326 crate::weight_loader::load_safetensors_with_progress(
1327 &xformer_paths,
1328 dtype,
1329 device,
1330 "Qwen-Image transformer",
1331 &self.base.progress,
1332 )?
1333 };
1334 Ok(QwenImageTransformer::BF16(
1335 QwenImageTransformer2DModel::new(cfg, xformer_vb)?,
1336 ))
1337 }
1338 }
1339 }
1340
1341 fn build_bf16_lora_var_builder<'a>(
1346 &self,
1347 xformer_paths: &[std::path::PathBuf],
1348 dtype: DType,
1349 device: &Device,
1350 is_fp8: bool,
1351 loras: &[mold_core::LoraWeight],
1352 ) -> Result<candle_nn::VarBuilder<'a>> {
1353 let adapters = super::lora::load_lora_adapters(loras, &self.base.progress)?;
1354 let specs: Vec<super::lora::QwenImageLoraSpec<'_>> = adapters
1355 .iter()
1356 .zip(loras.iter())
1357 .map(|(adapter, w)| super::lora::QwenImageLoraSpec {
1358 adapter: adapter.as_ref(),
1359 scale: w.scale,
1360 path_hash: super::lora::lora_path_hash(&w.path),
1361 })
1362 .collect();
1363
1364 let path_refs: Vec<&std::path::Path> = xformer_paths.iter().map(|p| p.as_path()).collect();
1365 let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::multi(&path_refs)? };
1366 let inner: Box<dyn candle_nn::var_builder::SimpleBackend> = if is_fp8 {
1367 self.base
1372 .progress
1373 .info("Detected FP8 safetensors — loading with LoRA-merging wrapper");
1374 Box::new(crate::weight_loader::NativeFp8Backend::from_mmap(tensors))
1375 } else {
1376 Box::new(tensors)
1379 };
1380
1381 let wrapped =
1382 super::lora::wrap_backend_with_lora(inner, &specs, &self.base.progress, None)?;
1383
1384 let target_dtype = if is_fp8 { DType::BF16 } else { dtype };
1385 Ok(candle_nn::VarBuilder::from_backend(
1386 wrapped,
1387 target_dtype,
1388 device.clone(),
1389 ))
1390 }
1391
1392 fn load_vae(&self, device: &Device, dtype: DType) -> Result<QwenImageVae> {
1394 let vb = self.load_vae_var_builder(device, dtype)?;
1395 Ok(QwenImageVae::from_var_builder(vb, device, dtype)?)
1396 }
1397
1398 fn load_vae_cpu_tensors(&self) -> Result<Option<Arc<HashMap<String, Tensor>>>> {
1399 let Some(shared_pool) = &self.shared_pool else {
1400 return Ok(None);
1401 };
1402 shared_pool
1403 .lock()
1404 .unwrap()
1405 .load_safetensors_cpu_tensors(std::slice::from_ref(&self.base.paths.vae))
1406 }
1407
1408 fn load_vae_var_builder<'a>(
1409 &self,
1410 device: &Device,
1411 dtype: DType,
1412 ) -> Result<candle_nn::VarBuilder<'a>> {
1413 if let Some(tensors) = self.load_vae_cpu_tensors()? {
1414 return Ok(encoders::park::varbuilder_from_parked(
1415 tensors.as_ref(),
1416 dtype,
1417 device,
1418 ));
1419 }
1420
1421 crate::weight_loader::load_safetensors_with_progress(
1422 std::slice::from_ref(&self.base.paths.vae),
1423 dtype,
1424 device,
1425 "Qwen-Image VAE",
1426 &self.base.progress,
1427 )
1428 }
1429
1430 fn resolve_text_encoder_source(
1435 &self,
1436 gpu_device: &Device,
1437 free_vram: u64,
1438 usage: Qwen2TextEncoderUsage,
1439 ) -> Result<ResolvedQwen2TextEncoder> {
1440 let preference = std::env::var("MOLD_QWEN2_VARIANT").ok();
1441 self.resolve_text_encoder_source_with_preference(
1442 gpu_device,
1443 free_vram,
1444 usage,
1445 preference.as_deref(),
1446 )
1447 }
1448
1449 fn resolve_text_encoder_source_with_preference(
1450 &self,
1451 gpu_device: &Device,
1452 free_vram: u64,
1453 usage: Qwen2TextEncoderUsage,
1454 preference: Option<&str>,
1455 ) -> Result<ResolvedQwen2TextEncoder> {
1456 let is_cuda = gpu_device.is_cuda();
1457 let is_metal = gpu_device.is_metal();
1458 let bf16_size_bytes = self
1459 .base
1460 .paths
1461 .text_encoder_files
1462 .iter()
1463 .filter_map(|p| std::fs::metadata(p).ok())
1464 .map(|m| m.len())
1465 .sum();
1466 if self.is_edit_family() {
1467 let mut resolved = Self::choose_text_encoder_source(
1468 preference,
1469 is_cuda,
1470 is_metal,
1471 free_vram,
1472 bf16_size_bytes,
1473 Qwen2TextEncoderUsage::Resident,
1474 )?;
1475 resolved.vision_paths = self.base.paths.text_encoder_files.clone();
1476 if resolved.is_gguf {
1477 let variant = mold_core::manifest::find_qwen2_vl_variant(&resolved.variant_label)
1478 .ok_or_else(|| {
1479 anyhow::anyhow!("unknown Qwen2.5-VL variant '{}'", resolved.variant_label)
1480 })?;
1481 resolved.paths = vec![
1482 crate::encoders::variant_resolution::resolve_qwen2_vl_gguf_path(
1483 &self.base.progress,
1484 variant,
1485 )?,
1486 ];
1487 } else {
1488 resolved.paths = self.base.paths.text_encoder_files.clone();
1489 }
1490 return Ok(resolved);
1491 }
1492 let mut resolved = Self::choose_text_encoder_source(
1493 preference,
1494 is_cuda,
1495 is_metal,
1496 free_vram,
1497 bf16_size_bytes,
1498 usage,
1499 )?;
1500
1501 if resolved.is_gguf {
1502 let variant = mold_core::manifest::find_qwen2_vl_variant(&resolved.variant_label)
1503 .ok_or_else(|| {
1504 anyhow::anyhow!("unknown Qwen2.5-VL variant '{}'", resolved.variant_label)
1505 })?;
1506 resolved.paths = vec![
1507 crate::encoders::variant_resolution::resolve_qwen2_vl_gguf_path(
1508 &self.base.progress,
1509 variant,
1510 )?,
1511 ];
1512 } else {
1513 resolved.paths = self.base.paths.text_encoder_files.clone();
1514 }
1515 resolved.vision_paths = vec![];
1516
1517 match preference {
1518 Some(tag) if tag != "auto" && tag != "bf16" => self.base.progress.info(&format!(
1519 "Using quantized Qwen2.5-VL {} ({}) on {} (explicit)",
1520 resolved.variant_label,
1521 fmt_gb(resolved.size_bytes),
1522 if resolved.auto_use_gpu { "GPU" } else { "CPU" },
1523 )),
1524 Some("bf16") => {}
1525 _ if is_metal && resolved.is_gguf && resolved.variant_label == "q6" => self
1526 .base
1527 .progress
1528 .info(&format!(
1529 "Metal auto mode selected quantized Qwen2.5-VL {} ({}) for lower memory pressure",
1530 resolved.variant_label,
1531 fmt_gb(resolved.size_bytes),
1532 )),
1533 _ if is_metal && resolved.is_gguf => self.base.progress.info(&format!(
1534 "Metal auto mode forcing quantized Qwen2.5-VL {} ({}) to avoid BF16 memory pressure",
1535 resolved.variant_label,
1536 fmt_gb(resolved.size_bytes),
1537 )),
1538 _ if is_cuda && resolved.is_gguf && resolved.auto_use_gpu => self.base.progress.info(
1539 &format!(
1540 "CUDA auto mode selected quantized Qwen2.5-VL {} ({}) on GPU",
1541 resolved.variant_label,
1542 fmt_gb(resolved.size_bytes),
1543 ),
1544 ),
1545 _ if is_cuda && resolved.is_gguf => self.base.progress.info(&format!(
1546 "CUDA auto mode selected quantized Qwen2.5-VL {} ({}) on CPU to avoid large BF16 host residency",
1547 resolved.variant_label,
1548 fmt_gb(resolved.size_bytes),
1549 )),
1550 _ => {}
1551 }
1552
1553 Ok(resolved)
1554 }
1555
1556 fn can_keep_transformer_hot_for_vae(loaded: &LoadedQwenImage) -> bool {
1557 Self::qwen_transformer_can_stay_hot_for_vae(
1558 loaded.device.is_cuda(),
1559 loaded.vae_device.is_cuda(),
1560 matches!(
1561 loaded.transformer.as_ref(),
1562 Some(QwenImageTransformer::Quantized(_))
1563 ),
1564 )
1565 }
1566
1567 fn qwen_transformer_can_stay_hot_for_vae(
1568 transformer_is_cuda: bool,
1569 vae_is_cuda: bool,
1570 transformer_is_quantized: bool,
1571 ) -> bool {
1572 transformer_is_cuda && vae_is_cuda && transformer_is_quantized
1573 }
1574
1575 fn decode_vae_gpu_only(
1576 latents: &Tensor,
1577 vae: &QwenImageVae,
1578 vae_device: &Device,
1579 sync_device: &Device,
1580 progress: &ProgressReporter,
1581 prefer_tiled: bool,
1582 ) -> Result<Tensor> {
1583 if vae_device.is_cuda() && prefer_tiled {
1584 progress.info("Selecting tiled GPU VAE decode proactively");
1585 return Self::decode_vae_tiled(latents, vae, vae_device, progress);
1586 }
1587
1588 let decode_latents = latents.to_device(vae_device)?.to_dtype(DType::F32)?;
1589 match vae.decode(&decode_latents) {
1590 Ok(image) => Ok(image),
1591 Err(e) if vae_device.is_cuda() && Self::is_oom_error(&e) => {
1592 progress.info(
1593 "Resident-transformer VAE decode OOM on GPU — retrying with tiled GPU decode before dropping transformer",
1594 );
1595 sync_device.synchronize()?;
1596 Self::decode_vae_tiled(latents, vae, vae_device, progress)
1597 }
1598 Err(e) => Err(e.into()),
1599 }
1600 }
1601
1602 fn load_text_encoder(
1603 &self,
1604 resolved: &ResolvedQwen2TextEncoder,
1605 tokenizer_path: &std::path::PathBuf,
1606 tokenizer: Arc<Tokenizer>,
1607 device: &Device,
1608 dtype: DType,
1609 preload_weights: bool,
1610 ) -> Result<encoders::qwen2_text::Qwen2TextEncoder> {
1611 if resolved.is_gguf {
1612 if preload_weights {
1613 encoders::qwen2_text::Qwen2TextEncoder::load_gguf_with_tokenizer(
1614 &resolved.paths[0],
1615 tokenizer_path,
1616 Some(tokenizer),
1617 device,
1618 dtype,
1619 &resolved.vision_paths,
1620 &self.base.progress,
1621 )
1622 } else {
1623 encoders::qwen2_text::Qwen2TextEncoder::prepare_gguf_with_tokenizer(
1624 &resolved.paths[0],
1625 tokenizer_path,
1626 Some(tokenizer),
1627 device,
1628 dtype,
1629 &resolved.vision_paths,
1630 )
1631 }
1632 } else {
1633 let is_fp8 = text_encoder_is_fp8(&resolved.paths);
1634 if is_fp8 {
1635 self.base
1636 .progress
1637 .info("Detected FP8 text encoder — loading as BF16 on GPU");
1638 }
1639 if preload_weights {
1640 encoders::qwen2_text::Qwen2TextEncoder::load_bf16_with_tokenizer(
1641 &resolved.paths,
1642 tokenizer_path,
1643 Some(tokenizer),
1644 device,
1645 dtype,
1646 self.is_edit_family(),
1647 &self.base.progress,
1648 )
1649 } else {
1650 encoders::qwen2_text::Qwen2TextEncoder::prepare_bf16_with_tokenizer(
1651 &resolved.paths,
1652 tokenizer_path,
1653 Some(tokenizer),
1654 device,
1655 dtype,
1656 self.is_edit_family(),
1657 )
1658 }
1659 }
1660 }
1661
1662 fn resolve_text_encoder_plan(
1664 &self,
1665 gpu_device: &Device,
1666 resolved: &ResolvedQwen2TextEncoder,
1667 free_vram: u64,
1668 ) -> (Qwen2TextEncoderPlan, String) {
1669 let is_cuda = gpu_device.is_cuda();
1670 let is_metal = gpu_device.is_metal();
1671 let plan = Self::qwen2_text_encoder_plan_for_mode(
1672 Qwen2TextEncoderMode::from_env(),
1673 is_cuda,
1674 is_metal,
1675 resolved,
1676 );
1677 let label = if plan.use_gpu { "GPU" } else { "CPU" };
1678 if plan.use_cpu_staging {
1679 self.base
1680 .progress
1681 .info("Qwen2.5 text encoder on GPU with CPU staging after encoding");
1682 } else if !plan.use_gpu {
1683 if resolved.is_gguf {
1684 self.base.progress.info(&format!(
1685 "Qwen2.5 text encoder on CPU ({} variant {}, {} free)",
1686 resolved.variant_label,
1687 fmt_gb(resolved.size_bytes),
1688 fmt_gb(free_vram),
1689 ));
1690 } else if is_metal || is_cuda {
1691 self.base.progress.info(&format!(
1692 "Qwen2.5 text encoder on CPU ({} free < {} threshold)",
1693 fmt_gb(free_vram),
1694 fmt_gb(QWEN2_FP16_VRAM_THRESHOLD),
1695 ));
1696 }
1697 }
1698 (plan, label.to_string())
1699 }
1700
1701 fn qwen2_text_encoder_plan_for_mode(
1702 mode: Qwen2TextEncoderMode,
1703 is_cuda: bool,
1704 is_metal: bool,
1705 resolved: &ResolvedQwen2TextEncoder,
1706 ) -> Qwen2TextEncoderPlan {
1707 match mode {
1708 Qwen2TextEncoderMode::Gpu => Qwen2TextEncoderPlan {
1709 use_gpu: is_cuda || is_metal,
1710 use_cpu_staging: false,
1711 },
1712 Qwen2TextEncoderMode::CpuStage => Qwen2TextEncoderPlan {
1713 use_gpu: is_cuda || is_metal,
1714 use_cpu_staging: is_cuda || is_metal,
1715 },
1716 Qwen2TextEncoderMode::Cpu => Qwen2TextEncoderPlan {
1717 use_gpu: false,
1718 use_cpu_staging: false,
1719 },
1720 Qwen2TextEncoderMode::Auto => Qwen2TextEncoderPlan {
1721 use_gpu: resolved.auto_use_gpu,
1722 use_cpu_staging: is_metal && resolved.auto_use_gpu && !resolved.is_gguf,
1723 },
1724 }
1725 }
1726
1727 pub fn load(&mut self) -> Result<()> {
1733 if self.base.loaded.is_some() {
1734 return Ok(());
1735 }
1736
1737 if self.base.load_strategy == LoadStrategy::Sequential {
1739 return Ok(());
1740 }
1741
1742 tracing::info!(model = %self.base.model_name, "loading Qwen-Image model components...");
1743
1744 let text_tokenizer_path = self.validate_paths()?;
1745 let transformer_ref = effective_device_ref(
1746 self.pending_placement.as_ref(),
1747 |adv| Some(adv.transformer),
1748 false,
1749 );
1750 let device = crate::device::resolve_device(Some(transformer_ref), || {
1751 crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)
1752 })?;
1753 let transformer_cfg = self.transformer_config();
1754 let transformer_is_quantized = self.detect_is_quantized();
1755 let dtype = crate::engine::gpu_dtype(&device);
1759
1760 let xformer_paths = self.transformer_paths();
1762 let xformer_label = if transformer_is_quantized {
1763 "Loading Qwen-Image transformer (quantized)".to_string()
1764 } else {
1765 format!(
1766 "Loading Qwen-Image transformer ({} shards)",
1767 xformer_paths.len()
1768 )
1769 };
1770 self.base.progress.stage_start(&xformer_label);
1771 let xformer_start = Instant::now();
1772 let transformer = self.load_transformer(
1773 &device,
1774 dtype,
1775 &transformer_cfg,
1776 QWEN_NATIVE_WIDTH,
1777 QWEN_NATIVE_HEIGHT,
1778 )?;
1779 self.base
1780 .progress
1781 .stage_done(&xformer_label, xformer_start.elapsed());
1782 tracing::info!("Qwen-Image transformer loaded");
1783
1784 let free_raw = free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
1787 let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
1788 let is_cuda = device.is_cuda();
1789 let is_metal = device.is_metal();
1790 if free_raw > 0 {
1791 self.base.progress.info(&format!(
1792 "Free VRAM after transformer: {}",
1793 fmt_gb(free_raw)
1794 ));
1795 }
1796
1797 let vae_on_gpu = should_use_gpu(is_cuda, is_metal, free, VAE_DECODE_VRAM_THRESHOLD);
1798 let vae_ref =
1799 effective_device_ref(self.pending_placement.as_ref(), |adv| Some(adv.vae), false);
1800 let vae_device = crate::device::resolve_device(Some(vae_ref), || {
1801 Ok(if vae_on_gpu {
1802 device.clone()
1803 } else {
1804 Device::Cpu
1805 })
1806 })?;
1807 let vae_on_gpu = !vae_device.is_cpu();
1808 let vae_dtype = DType::F32;
1811 let vae_device_label = if vae_on_gpu { "GPU" } else { "CPU" };
1812
1813 let vae_label = format!("Loading Qwen-Image VAE ({}, F32)", vae_device_label);
1815 self.base.progress.stage_start(&vae_label);
1816 let vae_start = Instant::now();
1817 let vae = self.load_vae(&vae_device, vae_dtype)?;
1818 self.base
1819 .progress
1820 .stage_done(&vae_label, vae_start.elapsed());
1821
1822 let resolved_text_encoder =
1824 self.resolve_text_encoder_source(&device, free, Qwen2TextEncoderUsage::Resident)?;
1825 let (te_plan, te_auto_device_label) =
1826 self.resolve_text_encoder_plan(&device, &resolved_text_encoder, free);
1827 let qwen_ref = effective_device_ref(self.pending_placement.as_ref(), |adv| adv.qwen, true);
1828 let auto_te_device = if te_plan.use_gpu {
1829 device.clone()
1830 } else {
1831 Device::Cpu
1832 };
1833 let te_device =
1834 crate::device::resolve_device(Some(qwen_ref), || Ok(auto_te_device.clone()))?;
1835 let te_use_gpu = !te_device.is_cpu();
1836 let te_device_label: String = if te_use_gpu == te_plan.use_gpu {
1837 te_auto_device_label
1838 } else if te_use_gpu {
1839 "GPU".into()
1840 } else {
1841 "CPU".into()
1842 };
1843 let te_dtype = Self::text_encoder_load_dtype(te_use_gpu, dtype);
1844
1845 let preload_text_encoder = self.should_preload_text_encoder();
1846 let te_label = if resolved_text_encoder.is_gguf {
1847 if preload_text_encoder {
1848 format!(
1849 "Loading Qwen2.5 text encoder ({} GGUF, {})",
1850 resolved_text_encoder.variant_label, te_device_label
1851 )
1852 } else {
1853 format!(
1854 "Preparing Qwen2.5 text encoder ({} GGUF, {})",
1855 resolved_text_encoder.variant_label, te_device_label
1856 )
1857 }
1858 } else if preload_text_encoder {
1859 format!(
1860 "Loading Qwen2.5 text encoder ({} shards, {})",
1861 resolved_text_encoder.paths.len(),
1862 te_device_label,
1863 )
1864 } else {
1865 format!(
1866 "Preparing Qwen2.5 text encoder ({} shards, {})",
1867 resolved_text_encoder.paths.len(),
1868 te_device_label,
1869 )
1870 };
1871 self.base.progress.stage_start(&te_label);
1872 let te_start = Instant::now();
1873 let text_tokenizer = self.load_text_tokenizer(&text_tokenizer_path)?;
1874 let text_encoder = self.load_text_encoder(
1875 &resolved_text_encoder,
1876 &text_tokenizer_path,
1877 text_tokenizer,
1878 &te_device,
1879 te_dtype,
1880 preload_text_encoder,
1881 )?;
1882 self.base.progress.stage_done(&te_label, te_start.elapsed());
1883 if preload_text_encoder {
1884 tracing::info!(device = %te_device_label, "Qwen2.5 text encoder loaded");
1885 } else {
1886 tracing::info!(device = %te_device_label, "Qwen2.5 text encoder prepared for staged loading");
1887 }
1888
1889 self.base.loaded = Some(LoadedQwenImage {
1890 transformer: Some(transformer),
1891 text_encoder,
1892 vae,
1893 vae_path: self.base.paths.vae.clone(),
1894 transformer_cfg,
1895 device,
1896 vae_device,
1897 dtype,
1898 });
1899
1900 tracing::info!(model = %self.base.model_name, "all Qwen-Image components loaded");
1901 Ok(())
1902 }
1903
1904 fn reload_transformer(
1906 &self,
1907 loaded: &mut LoadedQwenImage,
1908 width: usize,
1909 height: usize,
1910 ) -> Result<()> {
1911 let transformer = self.load_transformer(
1912 &loaded.device,
1913 loaded.dtype,
1914 &loaded.transformer_cfg,
1915 width,
1916 height,
1917 )?;
1918 loaded.transformer = Some(transformer);
1919 Ok(())
1920 }
1921
1922 fn generate_sequential(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1924 let text_tokenizer_path = self.validate_paths()?;
1925 let transformer_cfg = self.transformer_config();
1926
1927 let transformer_ref = effective_device_ref(
1928 self.pending_placement.as_ref(),
1929 |adv| Some(adv.transformer),
1930 false,
1931 );
1932 let device = crate::device::resolve_device(Some(transformer_ref), || {
1933 crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)
1934 })?;
1935 let dtype = crate::engine::gpu_dtype(&device);
1936 let transformer_is_quantized = self.detect_is_quantized();
1937
1938 let start = Instant::now();
1939 let seed = req.seed.unwrap_or_else(rand_seed);
1940
1941 let width = req.width as usize;
1942 let height = req.height as usize;
1943 let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
1946 let resolved_text_encoder =
1947 self.resolve_text_encoder_source(&device, free, Qwen2TextEncoderUsage::Sequential)?;
1948 let (plan, _device_label) =
1949 self.resolve_text_encoder_plan(&device, &resolved_text_encoder, free);
1950 let use_cpu_staging = plan.use_cpu_staging;
1951
1952 tracing::info!(
1953 prompt = %req.prompt,
1954 seed, width, height,
1955 steps = req.steps,
1956 "starting sequential Qwen-Image generation"
1957 );
1958
1959 self.base
1960 .progress
1961 .info("Using sequential loading (load-use-drop) to minimize peak memory");
1962
1963 let use_cfg = req.guidance > 1.0;
1965 let prompt_key = prompt_text_key(&req.prompt);
1966 let uncond_key = prompt_text_key(QWEN_EMPTY_NEGATIVE_PROMPT);
1967 let (prompt_cached, uncond_cached) = {
1968 let mut cache = self.prompt_cache.lock().expect("cache poisoned");
1969 let prompt_cached = cache.get_cloned(&prompt_key);
1970 let uncond_cached = if use_cfg {
1971 cache.get_cloned(&uncond_key)
1972 } else {
1973 None
1974 };
1975 (prompt_cached, uncond_cached)
1976 };
1977 let both_cached = prompt_cached.is_some() && (!use_cfg || uncond_cached.is_some());
1978
1979 let (mut encoder_hidden_states, mut encoder_attention_mask, mut uncond_hs, mut uncond_mask) =
1980 if both_cached {
1981 self.base.progress.cache_hit("prompt conditioning");
1982 let cached = prompt_cached.unwrap();
1983 let restore_device = if use_cpu_staging {
1984 &Device::Cpu
1985 } else {
1986 &device
1987 };
1988 let restore_dtype = if use_cpu_staging { DType::F32 } else { dtype };
1989 let (hs, mask) = cached.restore(restore_device, restore_dtype)?;
1990 let (u_hs, u_mask) = if use_cfg {
1991 let ucached = uncond_cached.unwrap();
1992 let (u_hs, u_mask) = ucached.restore(restore_device, restore_dtype)?;
1993 (Some(u_hs), Some(u_mask))
1994 } else {
1995 (None, None)
1996 };
1997 (hs, mask, u_hs, u_mask)
1998 } else {
1999 let (te_plan, te_auto_device_label) =
2000 self.resolve_text_encoder_plan(&device, &resolved_text_encoder, free);
2001 let qwen_ref =
2002 effective_device_ref(self.pending_placement.as_ref(), |adv| adv.qwen, true);
2003 let auto_te_device = if te_plan.use_gpu {
2004 device.clone()
2005 } else {
2006 Device::Cpu
2007 };
2008 let te_device =
2009 crate::device::resolve_device(Some(qwen_ref), || Ok(auto_te_device.clone()))?;
2010 let te_use_gpu = !te_device.is_cpu();
2011 let te_device_label: String = if te_use_gpu == te_plan.use_gpu {
2012 te_auto_device_label
2013 } else if te_use_gpu {
2014 "GPU".into()
2015 } else {
2016 "CPU".into()
2017 };
2018 let te_dtype = Self::text_encoder_load_dtype(te_use_gpu, dtype);
2019
2020 let te_label = if resolved_text_encoder.is_gguf {
2021 format!(
2022 "Loading Qwen2.5 text encoder ({} GGUF, {})",
2023 resolved_text_encoder.variant_label, te_device_label
2024 )
2025 } else {
2026 format!(
2027 "Loading Qwen2.5 text encoder ({} shards, {})",
2028 resolved_text_encoder.paths.len(),
2029 te_device_label,
2030 )
2031 };
2032 if te_plan.use_cpu_staging && device.is_metal() && !resolved_text_encoder.is_gguf {
2033 self.base.progress.info(
2034 "Skipping hard preflight for Qwen2.5 text encoder on Metal; sequential mode spills prompt conditioning to CPU after encoding",
2035 );
2036 } else {
2037 let te_activation_budget = crate::device::activation_bytes(
2038 req.width,
2039 req.height,
2040 1,
2041 crate::device::dtype_bytes(te_dtype),
2042 crate::device::ActivationFamily::SmallTransformer,
2043 );
2044 preflight_memory_check(
2045 "Qwen2.5 text encoder",
2046 resolved_text_encoder.size_bytes,
2047 te_activation_budget,
2048 )?;
2049 }
2050
2051 if let Some(status) = memory_status_string() {
2052 self.base.progress.info(&status);
2053 }
2054
2055 self.base.progress.stage_start(&te_label);
2056 let te_start = Instant::now();
2057 let text_tokenizer = self.load_text_tokenizer(&text_tokenizer_path)?;
2058 let mut text_encoder = self.load_text_encoder(
2059 &resolved_text_encoder,
2060 &text_tokenizer_path,
2061 text_tokenizer,
2062 &te_device,
2063 te_dtype,
2064 true,
2065 )?;
2066 self.base.progress.stage_done(&te_label, te_start.elapsed());
2067
2068 let (hs, mask) = Self::encode_prompt_cached(
2069 &self.base.progress,
2070 &self.prompt_cache,
2071 &mut text_encoder,
2072 &req.prompt,
2073 &device,
2074 dtype,
2075 )?;
2076 let (hs, mask) = Self::maybe_spill_conditioning(use_cpu_staging, hs, mask)?;
2077
2078 let (u_hs, u_mask) = if use_cfg {
2079 let (hs, mask) = Self::encode_prompt_cached(
2080 &self.base.progress,
2081 &self.prompt_cache,
2082 &mut text_encoder,
2083 QWEN_EMPTY_NEGATIVE_PROMPT,
2084 &device,
2085 dtype,
2086 )?;
2087 let (hs, mask) = Self::maybe_spill_conditioning(use_cpu_staging, hs, mask)?;
2088 (Some(hs), Some(mask))
2089 } else {
2090 (None, None)
2091 };
2092
2093 drop(text_encoder);
2094 device.synchronize()?;
2096 if let Some(status) = crate::device::memory_status_string() {
2097 if use_cpu_staging {
2098 self.base.progress.info(&format!(
2099 "Freed Qwen2.5 text encoder and spilled prompt conditioning to CPU — {status}"
2100 ));
2101 } else {
2102 self.base
2103 .progress
2104 .info(&format!("Freed Qwen2.5 text encoder — {status}"));
2105 }
2106 } else {
2107 if use_cpu_staging {
2108 self.base.progress.info(
2109 "Freed Qwen2.5 text encoder and spilled prompt conditioning to CPU",
2110 );
2111 } else {
2112 self.base.progress.info("Freed Qwen2.5 text encoder");
2113 }
2114 }
2115
2116 (hs, mask, u_hs, u_mask)
2117 };
2118
2119 if use_cfg {
2120 let ((cond_hs, cond_mask), (neg_hs, neg_mask)) = align_cfg_conditioning(
2121 &encoder_hidden_states,
2122 &encoder_attention_mask,
2123 uncond_hs.as_ref().expect("unconditional prompt missing"),
2124 uncond_mask.as_ref().expect("unconditional mask missing"),
2125 )?;
2126 encoder_hidden_states = cond_hs;
2127 encoder_attention_mask = cond_mask;
2128 uncond_hs = Some(neg_hs);
2129 uncond_mask = Some(neg_mask);
2130 }
2131
2132 let xformer_paths = self.transformer_paths();
2134 let xformer_size: u64 = xformer_paths
2135 .iter()
2136 .filter_map(|p| std::fs::metadata(p).ok())
2137 .map(|m| m.len())
2138 .sum();
2139 let xformer_activation_budget = crate::device::activation_bytes(
2140 req.width,
2141 req.height,
2142 if req.guidance > 1.0 { 2 } else { 1 },
2143 crate::device::dtype_bytes(dtype),
2144 crate::device::ActivationFamily::QwenImageDit,
2145 );
2146 preflight_memory_check(
2147 "Qwen-Image transformer",
2148 xformer_size,
2149 xformer_activation_budget,
2150 )?;
2151
2152 if let Some(status) = memory_status_string() {
2153 self.base.progress.info(&status);
2154 }
2155
2156 let xformer_label = if transformer_is_quantized {
2157 "Loading Qwen-Image transformer (quantized)".to_string()
2158 } else {
2159 format!(
2160 "Loading Qwen-Image transformer ({} shards)",
2161 xformer_paths.len()
2162 )
2163 };
2164 self.base.progress.stage_start(&xformer_label);
2165 let xformer_start = Instant::now();
2166 let transformer = self.load_transformer(&device, dtype, &transformer_cfg, width, height)?;
2167 self.base
2168 .progress
2169 .stage_done(&xformer_label, xformer_start.elapsed());
2170
2171 if use_cpu_staging {
2172 encoder_hidden_states = encoder_hidden_states.to_device(&device)?.to_dtype(dtype)?;
2173 encoder_attention_mask = encoder_attention_mask.to_device(&device)?;
2174 if let Some(hs) = uncond_hs.take() {
2175 uncond_hs = Some(hs.to_device(&device)?.to_dtype(dtype)?);
2176 }
2177 if let Some(mask) = uncond_mask.take() {
2178 uncond_mask = Some(mask.to_device(&device)?);
2179 }
2180 if let Some(status) = memory_status_string() {
2181 self.base.progress.info(&format!(
2182 "Restored prompt conditioning to GPU for denoising — {status}"
2183 ));
2184 } else {
2185 self.base
2186 .progress
2187 .info("Restored prompt conditioning to GPU for denoising");
2188 }
2189 }
2190
2191 let vae_downsample = 8;
2193 let latent_h = height / vae_downsample;
2194 let latent_w = width / vae_downsample;
2195 let is_img2img = req.source_image.is_some();
2196
2197 let (prepared_img2img_latents, inpaint_ctx) = if let Some(ref source_bytes) =
2199 req.source_image
2200 {
2201 let free_for_encode = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
2203 let encode_on_gpu = should_use_gpu(
2204 device.is_cuda(),
2205 device.is_metal(),
2206 free_for_encode,
2207 VAE_DECODE_VRAM_THRESHOLD,
2208 );
2209 let encode_device = if encode_on_gpu {
2210 device.clone()
2211 } else {
2212 Device::Cpu
2213 };
2214 let encode_label = if encode_on_gpu { "GPU" } else { "CPU" };
2215
2216 let vae_label = format!("Loading Qwen-Image VAE ({}, F32) for encode", encode_label);
2217 self.base.progress.stage_start(&vae_label);
2218 let vae_start = Instant::now();
2219 let encode_vae = self.load_vae(&encode_device, DType::F32)?;
2220 self.base
2221 .progress
2222 .stage_done(&vae_label, vae_start.elapsed());
2223
2224 let encoded = Self::encode_vae_with_fallback(
2225 source_bytes,
2226 req.width,
2227 req.height,
2228 &encode_vae,
2229 &encode_device,
2230 &device,
2231 &self.base.progress,
2232 || self.load_vae(&Device::Cpu, DType::F32),
2233 )?;
2234 let encoded = encoded.to_device(&device)?.to_dtype(dtype)?;
2235 let start_sigma = QwenImageScheduler::new_img2img(
2236 req.steps as usize,
2237 image_seq_len(latent_h, latent_w, transformer_cfg.patch_size),
2238 req.strength,
2239 )
2240 .0
2241 .initial_sigma();
2242 let prepared = crate::img2img::prepare_flow_match_img2img(
2243 &encoded,
2244 seed,
2245 &[1, 16, latent_h, latent_w],
2246 start_sigma,
2247 req.mask_image.as_deref(),
2248 latent_h,
2249 latent_w,
2250 &device,
2251 dtype,
2252 )?;
2253
2254 drop(encode_vae);
2256 device.synchronize()?;
2257
2258 tracing::info!(
2259 strength = req.strength,
2260 "img2img: encoded source image to latents"
2261 );
2262
2263 (Some(prepared.initial_latents), prepared.inpaint_ctx)
2264 } else {
2265 (None, None)
2266 };
2267
2268 let image_seq_len = image_seq_len(latent_h, latent_w, transformer_cfg.patch_size);
2269 let (mut scheduler, num_steps) = if is_img2img {
2270 QwenImageScheduler::new_img2img(req.steps as usize, image_seq_len, req.strength)
2271 } else {
2272 let sched = QwenImageScheduler::new(req.steps as usize, image_seq_len);
2273 let n = sched.num_steps();
2274 (sched, n)
2275 };
2276
2277 let mut latents = if let Some(initial) = &prepared_img2img_latents {
2279 initial.clone()
2280 } else {
2281 let noise =
2282 crate::engine::seeded_randn(seed, &[1, 16, latent_h, latent_w], &device, dtype)?;
2283 (noise * scheduler.initial_sigma())?
2284 };
2285
2286 let denoise_label = format!("Denoising ({} steps)", num_steps);
2287 self.base.progress.stage_start(&denoise_label);
2288 let denoise_start = Instant::now();
2289
2290 if std::env::var_os("MOLD_QWEN_DEBUG").is_some() {
2291 eprintln!(
2292 "[qwen-debug] cfg={} guidance={:.1} image_seq_len={} sigmas[0]={:.4} sigmas[last]={:.4} img2img={}",
2293 use_cfg,
2294 req.guidance,
2295 image_seq_len,
2296 scheduler.sigmas[0],
2297 scheduler.sigmas[scheduler.sigmas.len() - 1],
2298 is_img2img,
2299 );
2300 }
2301
2302 let use_batched_cfg = use_cfg && transformer.supports_cfg_batching();
2303 if use_cfg && !use_batched_cfg {
2304 self.base.progress.info(
2305 "Low-memory quantized Qwen CUDA path detected — disabling CFG batching to reduce peak CUDA memory",
2306 );
2307 }
2308
2309 let (batched_hs, batched_mask) = if use_batched_cfg {
2312 let hs = Tensor::cat(&[&encoder_hidden_states, uncond_hs.as_ref().unwrap()], 0)?;
2313 let mask = Tensor::cat(&[&encoder_attention_mask, uncond_mask.as_ref().unwrap()], 0)?;
2314 (hs, mask)
2315 } else {
2316 (
2317 encoder_hidden_states.clone(),
2318 encoder_attention_mask.clone(),
2319 )
2320 };
2321
2322 for step in 0..num_steps {
2323 let step_start = Instant::now();
2324 let t = scheduler.current_timestep();
2325 let noise_pred = if use_cfg {
2326 let (cond_pred, uncond_pred) = if use_batched_cfg {
2327 let t_tensor =
2328 Tensor::from_vec(vec![t as f32; 2], (2,), &device)?.to_dtype(dtype)?;
2329 let batched_latents = Tensor::cat(&[&latents, &latents], 0)?;
2330 let batched_pred = transformer.forward(
2331 &batched_latents,
2332 &t_tensor,
2333 &batched_hs,
2334 &batched_mask,
2335 )?;
2336 (batched_pred.narrow(0, 0, 1)?, batched_pred.narrow(0, 1, 1)?)
2337 } else {
2338 let t_tensor =
2339 Tensor::from_vec(vec![t as f32], (1,), &device)?.to_dtype(dtype)?;
2340 (
2341 transformer.forward(
2342 &latents,
2343 &t_tensor,
2344 &encoder_hidden_states,
2345 &encoder_attention_mask,
2346 )?,
2347 transformer.forward(
2348 &latents,
2349 &t_tensor,
2350 uncond_hs.as_ref().unwrap(),
2351 uncond_mask.as_ref().unwrap(),
2352 )?,
2353 )
2354 };
2355 if step == 0 {
2356 Self::debug_tensor_stats("cond_pred[0]", &cond_pred);
2357 Self::debug_tensor_stats("uncond_pred[0]", &uncond_pred);
2358 }
2359 let cond_f32 = cond_pred.to_dtype(DType::F32)?;
2362 let uncond_f32 = uncond_pred.to_dtype(DType::F32)?;
2363 let comb = (&uncond_f32 + ((&cond_f32 - &uncond_f32)? * req.guidance)?)?;
2364 let cond_norm = cond_f32.sqr()?.sum_keepdim(1)?.sqrt()?;
2365 let comb_norm = comb.sqr()?.sum_keepdim(1)?.sqrt()?.clamp(1e-8, f64::MAX)?;
2366 let rescaled = comb.broadcast_mul(&(cond_norm / comb_norm)?)?;
2367 rescaled.to_dtype(dtype)?
2368 } else {
2369 let t_tensor = Tensor::from_vec(vec![t as f32], (1,), &device)?.to_dtype(dtype)?;
2370 transformer.forward(
2371 &latents,
2372 &t_tensor,
2373 &encoder_hidden_states,
2374 &encoder_attention_mask,
2375 )?
2376 };
2377 if step == 0 || step == num_steps / 2 || step == num_steps - 1 {
2378 Self::debug_tensor_stats(&format!("noise_pred[{step}]"), &noise_pred);
2379 Self::debug_tensor_stats(&format!("latents[{step}]"), &latents);
2380 }
2381 if step == 0 {
2382 Self::validate_qwen_tensor_boundary("noise_pred[0]", &noise_pred)?;
2383 }
2384 latents = scheduler.step(&noise_pred, &latents)?;
2385 if step == num_steps - 1 {
2386 Self::validate_qwen_tensor_boundary("latents_final", &latents)?;
2387 }
2388
2389 if let Some(ref ctx) = inpaint_ctx {
2391 latents = crate::img2img::apply_flow_match_inpaint(
2392 &latents,
2393 ctx,
2394 scheduler.sigmas[step + 1],
2395 )?;
2396 }
2397
2398 if std::env::var_os("MOLD_QWEN_DEBUG").is_some() {
2399 let n = latents
2400 .ne(&latents)?
2401 .to_dtype(candle_core::DType::U32)?
2402 .sum_all()?
2403 .to_scalar::<u32>()?;
2404 if n > 0 {
2405 eprintln!(
2406 "[qwen-nan] NaN in latents AFTER step {step}: {n}/{}",
2407 latents.elem_count()
2408 );
2409 }
2410 }
2411 self.base.progress.emit(ProgressEvent::DenoiseStep {
2412 step: step + 1,
2413 total: num_steps,
2414 elapsed: step_start.elapsed(),
2415 });
2416 }
2417
2418 self.base
2419 .progress
2420 .stage_done(&denoise_label, denoise_start.elapsed());
2421
2422 drop(transformer);
2424 drop(encoder_hidden_states);
2425 drop(encoder_attention_mask);
2426 drop(uncond_hs);
2427 drop(uncond_mask);
2428 device.synchronize()?;
2429 self.base.progress.info("Freed Qwen-Image transformer");
2430
2431 if let Some(status) = memory_status_string() {
2433 self.base.progress.info(&status);
2434 }
2435
2436 let free_for_vae = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
2438 let vae_on_gpu = should_use_gpu(
2439 device.is_cuda(),
2440 device.is_metal(),
2441 free_for_vae,
2442 VAE_DECODE_VRAM_THRESHOLD,
2443 );
2444 let vae_ref =
2445 effective_device_ref(self.pending_placement.as_ref(), |adv| Some(adv.vae), false);
2446 let vae_device = crate::device::resolve_device(Some(vae_ref), || {
2447 Ok(if vae_on_gpu {
2448 device.clone()
2449 } else {
2450 Device::Cpu
2451 })
2452 })?;
2453 let vae_on_gpu = !vae_device.is_cpu();
2454 let vae_dtype = DType::F32;
2457 let vae_device_label = if vae_on_gpu { "GPU" } else { "CPU" };
2458
2459 let vae_label = format!("Loading Qwen-Image VAE ({}, F32)", vae_device_label);
2460 self.base.progress.stage_start(&vae_label);
2461 let vae_start = Instant::now();
2462 let vae = self.load_vae(&vae_device, vae_dtype)?;
2463 self.base
2464 .progress
2465 .stage_done(&vae_label, vae_start.elapsed());
2466
2467 self.base.progress.stage_start("VAE decode");
2468 let vae_decode_start = Instant::now();
2469 let free_for_decode = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
2470 let prefer_tiled = Self::should_proactively_tile_vae_decode(
2471 req.width,
2472 req.height,
2473 vae_device.is_cuda(),
2474 free_for_decode,
2475 );
2476
2477 let image = Self::decode_vae_with_fallback(
2478 &latents,
2479 &vae,
2480 &vae_device,
2481 &device,
2482 &self.base.progress,
2483 prefer_tiled,
2484 || self.load_vae(&Device::Cpu, DType::F32),
2485 )?;
2486 Self::validate_qwen_tensor_boundary("image_pre_postprocess", &image)?;
2487 Self::debug_tensor_stats("image_pre_postprocess", &image);
2488 let image = postprocess_image(&image)?;
2489 let post_stats = Self::validate_qwen_tensor_boundary("image_postprocess", &image)?;
2490 Self::debug_tensor_stats("image_postprocess", &image);
2491 let image = image.i(0)?;
2492 if Self::near_black_image_stats(post_stats) {
2493 self.base.progress.info(
2494 "Qwen diagnostic: decoded image is near-black after VAE postprocess; inspect MOLD_QWEN_DEBUG tensor stats to separate denoise math from VAE decode",
2495 );
2496 tracing::warn!(
2497 min = post_stats.min,
2498 max = post_stats.max,
2499 mean = post_stats.mean,
2500 "Qwen decoded image is near-black after VAE postprocess"
2501 );
2502 }
2503
2504 self.base
2505 .progress
2506 .stage_done("VAE decode", vae_decode_start.elapsed());
2507
2508 let output_metadata = build_output_metadata(req, seed, None);
2509 let image_bytes = encode_image(
2510 &image,
2511 req.resolved_output_format(),
2512 req.width,
2513 req.height,
2514 output_metadata.as_ref(),
2515 )?;
2516
2517 let generation_time_ms = start.elapsed().as_millis() as u64;
2518 tracing::info!(
2519 generation_time_ms,
2520 seed,
2521 "sequential Qwen-Image generation complete"
2522 );
2523
2524 Ok(GenerateResponse {
2525 images: vec![ImageData {
2526 data: image_bytes,
2527 format: req.resolved_output_format(),
2528 width: req.width,
2529 height: req.height,
2530 index: 0,
2531 }],
2532 generation_time_ms,
2533 model: req.model.clone(),
2534 seed_used: seed,
2535 video: None,
2536 gpu: None,
2537 })
2538 }
2539
2540 fn generate_edit_loaded(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
2541 let progress = &self.base.progress;
2542 let start = Instant::now();
2543
2544 let loaded_ref = self
2545 .base
2546 .loaded
2547 .as_ref()
2548 .ok_or_else(|| anyhow::anyhow!("model not loaded"))?;
2549 let needs_reload = loaded_ref.transformer.is_none();
2550 if needs_reload {
2551 let mut loaded_mut = self
2552 .base
2553 .loaded
2554 .take()
2555 .ok_or_else(|| anyhow::anyhow!("model not loaded"))?;
2556 progress.stage_start("Reloading Qwen-Image transformer");
2557 let reload_start = Instant::now();
2558 self.reload_transformer(&mut loaded_mut, req.width as usize, req.height as usize)?;
2559 progress.stage_done("Reloading Qwen-Image transformer", reload_start.elapsed());
2560 self.base.loaded = Some(loaded_mut);
2561 }
2562
2563 let is_edit_family = self.is_edit_family();
2564 let loaded = self
2565 .base
2566 .loaded
2567 .as_mut()
2568 .ok_or_else(|| anyhow::anyhow!("model not loaded"))?;
2569 let seed = req.seed.unwrap_or_else(rand_seed);
2570 let width = req.width as usize;
2571 let height = req.height as usize;
2572 let edit_images = req
2573 .edit_images
2574 .as_ref()
2575 .ok_or_else(|| anyhow::anyhow!("qwen-image-edit requires edit_images"))?;
2576 let use_cfg = req.guidance > 1.0;
2577 let negative_prompt = req
2578 .negative_prompt
2579 .as_deref()
2580 .unwrap_or(QWEN_EMPTY_NEGATIVE_PROMPT);
2581 let formatted_prompt = Self::qwen_image_edit_prompt(&req.prompt, edit_images.len());
2582 let formatted_negative = Self::qwen_image_edit_prompt(negative_prompt, edit_images.len());
2583
2584 tracing::info!(
2585 prompt = %req.prompt,
2586 seed,
2587 width,
2588 height,
2589 steps = req.steps,
2590 edit_images = edit_images.len(),
2591 "starting Qwen-Image edit generation"
2592 );
2593
2594 if loaded.text_encoder.model.is_none() {
2595 let label = if loaded.text_encoder.is_parked() {
2596 "Unparking Qwen2.5 encoder (CPU→GPU)"
2597 } else {
2598 "Reloading Qwen2.5 encoder"
2599 };
2600 progress.stage_start(label);
2601 let reload_start = Instant::now();
2602 if loaded.text_encoder.is_parked() {
2603 loaded.text_encoder.unpark_to_gpu(progress)?;
2604 } else {
2605 loaded.text_encoder.reload(progress)?;
2606 }
2607 progress.stage_done(label, reload_start.elapsed());
2608 }
2609
2610 progress.stage_start("Encoding prompt (Qwen2.5 edit)");
2611 let encode_start = Instant::now();
2612 let (encoder_hidden_states, encoder_attention_mask, _) =
2613 loaded.text_encoder.encode_formatted_multimodal(
2614 &formatted_prompt,
2615 edit_images,
2616 &loaded.device,
2617 loaded.dtype,
2618 )?;
2619 progress.stage_done("Encoding prompt (Qwen2.5 edit)", encode_start.elapsed());
2620 let (encoder_hidden_states, encoder_attention_mask, uncond_hs, uncond_mask) = if use_cfg {
2621 progress.stage_start("Encoding negative prompt (Qwen2.5 edit)");
2622 let neg_start = Instant::now();
2623 let (hs, mask, _) = loaded.text_encoder.encode_formatted_multimodal(
2624 &formatted_negative,
2625 edit_images,
2626 &loaded.device,
2627 loaded.dtype,
2628 )?;
2629 progress.stage_done(
2630 "Encoding negative prompt (Qwen2.5 edit)",
2631 neg_start.elapsed(),
2632 );
2633 let ((cond_hs, cond_mask), (neg_hs, neg_mask)) = align_cfg_conditioning(
2634 &encoder_hidden_states,
2635 &encoder_attention_mask,
2636 &hs,
2637 &mask,
2638 )?;
2639 (cond_hs, cond_mask, Some(neg_hs), Some(neg_mask))
2640 } else {
2641 (encoder_hidden_states, encoder_attention_mask, None, None)
2642 };
2643
2644 let drop_text_encoder = is_edit_family || loaded.text_encoder.on_gpu;
2645 if drop_text_encoder {
2646 let park_mode = crate::device::keep_te_in_ram()
2647 && !loaded.device.is_metal()
2648 && !loaded.text_encoder.is_quantized;
2649 if park_mode {
2650 loaded.text_encoder.park_to_cpu()?;
2651 tracing::info!(
2652 on_gpu = loaded.text_encoder.on_gpu,
2653 "Qwen2.5 text encoder parked to CPU host RAM after edit conditioning"
2654 );
2655 } else {
2656 loaded.text_encoder.drop_weights();
2657 tracing::info!(
2658 on_gpu = loaded.text_encoder.on_gpu,
2659 "Qwen2.5 text encoder dropped after edit conditioning"
2660 );
2661 }
2662 }
2663
2664 let mut packed_input_storage = Vec::with_capacity(edit_images.len());
2665 let mut img_shapes = vec![(1usize, height / 16, width / 16)];
2666 progress.stage_start("Encoding edit images (VAE)");
2667 let encode_start = Instant::now();
2668 for image_bytes in edit_images {
2669 let (vae_width, vae_height) =
2670 Self::qwen_image_edit_image_dims(image_bytes, QWEN_IMAGE_EDIT_VAE_AREA)?;
2671 let encoded = Self::encode_vae_with_fallback(
2672 image_bytes,
2673 vae_width,
2674 vae_height,
2675 &loaded.vae,
2676 &loaded.vae_device,
2677 &loaded.device,
2678 progress,
2679 || {
2680 Ok(QwenImageVae::load(
2681 &loaded.vae_path,
2682 &Device::Cpu,
2683 DType::F32,
2684 progress,
2685 )?)
2686 },
2687 )?
2688 .to_device(&loaded.device)?
2689 .to_dtype(loaded.dtype)?;
2690 img_shapes.push((1, encoded.dim(2)? / 2, encoded.dim(3)? / 2));
2691 packed_input_storage.push(Self::pack_latents_4d(&encoded)?);
2692 }
2693 progress.stage_done("Encoding edit images (VAE)", encode_start.elapsed());
2694
2695 let packed_inputs = if packed_input_storage.is_empty() {
2696 None
2697 } else {
2698 let tensors = packed_input_storage.iter().collect::<Vec<_>>();
2699 Some(Tensor::cat(&tensors, 1)?)
2700 };
2701
2702 let noise = crate::engine::seeded_randn(
2703 seed,
2704 &[1, 16, height / 8, width / 8],
2705 &loaded.device,
2706 loaded.dtype,
2707 )?;
2708 let mut scheduler =
2709 QwenImageScheduler::new(req.steps as usize, (height / 16) * (width / 16));
2710 let num_steps = scheduler.num_steps();
2711 let mut latents = Self::pack_latents_4d(&(noise * scheduler.initial_sigma())?)?;
2712 let output_seq_len = latents.dim(1)?;
2713
2714 let denoise_label = format!("Denoising edit ({} steps)", num_steps);
2715 progress.stage_start(&denoise_label);
2716 let denoise_start = Instant::now();
2717
2718 {
2719 let transformer = loaded
2720 .transformer
2721 .as_ref()
2722 .expect("transformer must be loaded for denoising");
2723 let use_batched_cfg = use_cfg && transformer.supports_cfg_batching();
2724 let (batched_hs, batched_mask) = if use_batched_cfg {
2725 let hs = Tensor::cat(&[&encoder_hidden_states, uncond_hs.as_ref().unwrap()], 0)?;
2726 let mask =
2727 Tensor::cat(&[&encoder_attention_mask, uncond_mask.as_ref().unwrap()], 0)?;
2728 (hs, mask)
2729 } else {
2730 (
2731 encoder_hidden_states.clone(),
2732 encoder_attention_mask.clone(),
2733 )
2734 };
2735
2736 for step in 0..num_steps {
2737 let step_start = Instant::now();
2738 let t = scheduler.current_timestep();
2739 let timestep = if use_batched_cfg {
2740 Tensor::from_vec(vec![t as f32; 2], (2,), &loaded.device)?
2741 .to_dtype(loaded.dtype)?
2742 } else {
2743 Tensor::from_vec(vec![t as f32], (1,), &loaded.device)?
2744 .to_dtype(loaded.dtype)?
2745 };
2746
2747 let latent_model_input = if let Some(ref packed_inputs) = packed_inputs {
2748 Tensor::cat(&[&latents, packed_inputs], 1)?
2749 } else {
2750 latents.clone()
2751 };
2752
2753 let noise_pred = if use_cfg {
2754 let (cond_pred, uncond_pred) = if use_batched_cfg {
2755 let batched_input =
2756 Tensor::cat(&[&latent_model_input, &latent_model_input], 0)?;
2757 let pred = transformer.forward_packed(
2758 &batched_input,
2759 ×tep,
2760 &batched_hs,
2761 &batched_mask,
2762 &img_shapes,
2763 )?;
2764 (
2765 pred.narrow(0, 0, 1)?.narrow(1, 0, output_seq_len)?,
2766 pred.narrow(0, 1, 1)?.narrow(1, 0, output_seq_len)?,
2767 )
2768 } else {
2769 (
2770 transformer
2771 .forward_packed(
2772 &latent_model_input,
2773 ×tep,
2774 &encoder_hidden_states,
2775 &encoder_attention_mask,
2776 &img_shapes,
2777 )?
2778 .narrow(1, 0, output_seq_len)?,
2779 transformer
2780 .forward_packed(
2781 &latent_model_input,
2782 ×tep,
2783 uncond_hs.as_ref().unwrap(),
2784 uncond_mask.as_ref().unwrap(),
2785 &img_shapes,
2786 )?
2787 .narrow(1, 0, output_seq_len)?,
2788 )
2789 };
2790
2791 let cond_f32 = cond_pred.to_dtype(DType::F32)?;
2792 let uncond_f32 = uncond_pred.to_dtype(DType::F32)?;
2793 let comb = (&uncond_f32 + ((&cond_f32 - &uncond_f32)? * req.guidance)?)?;
2794 let cond_norm = cond_f32.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?;
2795 let comb_norm = comb
2796 .sqr()?
2797 .sum_keepdim(D::Minus1)?
2798 .sqrt()?
2799 .clamp(1e-8, f64::MAX)?;
2800 comb.broadcast_mul(&(cond_norm / comb_norm)?)?
2801 .to_dtype(loaded.dtype)?
2802 } else {
2803 transformer
2804 .forward_packed(
2805 &latent_model_input,
2806 ×tep,
2807 &encoder_hidden_states,
2808 &encoder_attention_mask,
2809 &img_shapes,
2810 )?
2811 .narrow(1, 0, output_seq_len)?
2812 };
2813
2814 latents = scheduler.step(&noise_pred, &latents)?;
2815 progress.emit(ProgressEvent::DenoiseStep {
2816 step: step + 1,
2817 total: num_steps,
2818 elapsed: step_start.elapsed(),
2819 });
2820 }
2821 }
2822
2823 progress.stage_done(&denoise_label, denoise_start.elapsed());
2824
2825 let latents = Self::unpack_latents_packed(&latents, height / 8, width / 8)?;
2826 let free_for_decode = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
2827 let prefer_tiled = Self::should_proactively_tile_vae_decode(
2828 req.width,
2829 req.height,
2830 loaded.vae_device.is_cuda(),
2831 free_for_decode,
2832 );
2833 let image = Self::decode_vae_with_fallback(
2834 &latents,
2835 &loaded.vae,
2836 &loaded.vae_device,
2837 &loaded.device,
2838 progress,
2839 prefer_tiled,
2840 || {
2841 Ok(QwenImageVae::load(
2842 &loaded.vae_path,
2843 &Device::Cpu,
2844 DType::F32,
2845 progress,
2846 )?)
2847 },
2848 )?;
2849 let image = postprocess_image(&image)?.i(0)?;
2850 let output_metadata = build_output_metadata(req, seed, None);
2851 let image_bytes = encode_image(
2852 &image,
2853 req.resolved_output_format(),
2854 req.width,
2855 req.height,
2856 output_metadata.as_ref(),
2857 )?;
2858
2859 Ok(GenerateResponse {
2860 images: vec![ImageData {
2861 data: image_bytes,
2862 format: req.resolved_output_format(),
2863 width: req.width,
2864 height: req.height,
2865 index: 0,
2866 }],
2867 generation_time_ms: start.elapsed().as_millis() as u64,
2868 model: req.model.clone(),
2869 seed_used: seed,
2870 video: None,
2871 gpu: None,
2872 })
2873 }
2874}
2875
2876impl QwenImageEngine {
2877 fn generate_inner(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
2878 if req.scheduler.is_some() {
2879 tracing::warn!(
2880 "scheduler selection not supported for Qwen-Image (flow-matching), ignoring"
2881 );
2882 }
2883
2884 if self.is_edit_family() {
2885 let sequential = self.base.load_strategy == LoadStrategy::Sequential;
2886 if sequential && self.base.loaded.is_none() {
2887 let original = self.base.load_strategy;
2888 self.base.load_strategy = LoadStrategy::Eager;
2889 let load_result = self.load();
2890 self.base.load_strategy = original;
2891 load_result?;
2892 }
2893 if self.base.loaded.is_none() {
2894 bail!("model not loaded -- call load() first");
2895 }
2896 let result = self.generate_edit_loaded(req);
2897 if sequential {
2898 self.unload();
2899 }
2900 return result;
2901 }
2902
2903 if self.base.load_strategy == LoadStrategy::Sequential {
2905 return self.generate_sequential(req);
2906 }
2907
2908 if self.base.loaded.is_none() {
2910 bail!("model not loaded -- call load() first");
2911 }
2912
2913 let progress = &self.base.progress;
2914 let gpu_ordinal = self.base.gpu_ordinal;
2915 let start = Instant::now();
2916
2917 let loaded_ref = self
2919 .base
2920 .loaded
2921 .as_ref()
2922 .ok_or_else(|| anyhow::anyhow!("model not loaded"))?;
2923 let needs_reload = loaded_ref.transformer.is_none();
2924 if needs_reload {
2925 let mut loaded_mut = self
2926 .base
2927 .loaded
2928 .take()
2929 .ok_or_else(|| anyhow::anyhow!("model not loaded"))?;
2930 progress.stage_start("Reloading Qwen-Image transformer");
2931 let reload_start = Instant::now();
2932 self.reload_transformer(&mut loaded_mut, req.width as usize, req.height as usize)?;
2933 progress.stage_done("Reloading Qwen-Image transformer", reload_start.elapsed());
2934 self.base.loaded = Some(loaded_mut);
2935 }
2936
2937 let loaded = self
2938 .base
2939 .loaded
2940 .as_mut()
2941 .ok_or_else(|| anyhow::anyhow!("model not loaded"))?;
2942 let seed = req.seed.unwrap_or_else(rand_seed);
2943
2944 let width = req.width as usize;
2945 let height = req.height as usize;
2946
2947 tracing::info!(
2948 prompt = %req.prompt,
2949 seed, width, height,
2950 steps = req.steps,
2951 "starting Qwen-Image generation"
2952 );
2953
2954 let use_cfg = req.guidance > 1.0;
2955 let prompt_key = prompt_text_key(&req.prompt);
2956 let uncond_key = prompt_text_key(QWEN_EMPTY_NEGATIVE_PROMPT);
2957 let prompt_cached = self
2958 .prompt_cache
2959 .lock()
2960 .expect("cache poisoned")
2961 .get_cloned(&prompt_key);
2962 let uncond_cached = if use_cfg {
2963 self.prompt_cache
2964 .lock()
2965 .expect("cache poisoned")
2966 .get_cloned(&uncond_key)
2967 } else {
2968 None
2969 };
2970 let both_cached = prompt_cached.is_some() && (!use_cfg || uncond_cached.is_some());
2971
2972 let (encoder_hidden_states, encoder_attention_mask, uncond_hs, uncond_mask) = if both_cached
2973 {
2974 let cached = prompt_cached.expect("prompt cache unexpectedly missing");
2975 progress.cache_hit("prompt conditioning");
2976 let (hs, mask) = cached.restore(&loaded.device, loaded.dtype)?;
2977 let (u_hs, u_mask) = if use_cfg {
2978 progress.cache_hit("unconditional conditioning");
2979 let ucached =
2980 uncond_cached.expect("unconditional prompt cache unexpectedly missing");
2981 let (u_hs, u_mask) = ucached.restore(&loaded.device, loaded.dtype)?;
2982 (Some(u_hs), Some(u_mask))
2983 } else {
2984 (None, None)
2985 };
2986 (hs, mask, u_hs, u_mask)
2987 } else {
2988 if loaded.text_encoder.model.is_none() {
2989 let label = if loaded.text_encoder.is_parked() {
2990 "Unparking Qwen2.5 encoder (CPU→GPU)"
2991 } else {
2992 "Reloading Qwen2.5 encoder"
2993 };
2994 progress.stage_start(label);
2995 let reload_start = Instant::now();
2996 if loaded.text_encoder.is_parked() {
2997 loaded.text_encoder.unpark_to_gpu(progress)?;
2998 } else {
2999 loaded.text_encoder.reload(progress)?;
3000 }
3001 progress.stage_done(label, reload_start.elapsed());
3002 }
3003
3004 let (hs, mask) = Self::encode_prompt_cached(
3005 progress,
3006 &self.prompt_cache,
3007 &mut loaded.text_encoder,
3008 &req.prompt,
3009 &loaded.device,
3010 loaded.dtype,
3011 )?;
3012
3013 let (u_hs, u_mask) = if use_cfg {
3014 let (hs, mask) = Self::encode_prompt_cached(
3015 progress,
3016 &self.prompt_cache,
3017 &mut loaded.text_encoder,
3018 QWEN_EMPTY_NEGATIVE_PROMPT,
3019 &loaded.device,
3020 loaded.dtype,
3021 )?;
3022 (Some(hs), Some(mask))
3023 } else {
3024 (None, None)
3025 };
3026
3027 (hs, mask, u_hs, u_mask)
3028 };
3029
3030 let (encoder_hidden_states, encoder_attention_mask, uncond_hs, uncond_mask) = if use_cfg {
3031 let ((cond_hs, cond_mask), (neg_hs, neg_mask)) = align_cfg_conditioning(
3032 &encoder_hidden_states,
3033 &encoder_attention_mask,
3034 uncond_hs.as_ref().expect("unconditional prompt missing"),
3035 uncond_mask.as_ref().expect("unconditional mask missing"),
3036 )?;
3037 (cond_hs, cond_mask, Some(neg_hs), Some(neg_mask))
3038 } else {
3039 (
3040 encoder_hidden_states,
3041 encoder_attention_mask,
3042 uncond_hs,
3043 uncond_mask,
3044 )
3045 };
3046
3047 if loaded.text_encoder.on_gpu {
3049 let free_after_encode = usable_free_vram_bytes(gpu_ordinal).unwrap_or(0);
3050 let required_for_residency = Self::qwen2_hot_text_encoder_required_vram(
3051 req.width,
3052 req.height,
3053 if req.guidance > 1.0 { 2 } else { 1 },
3054 loaded.dtype,
3055 );
3056 let action =
3057 Self::qwen2_text_encoder_post_encode_action(Qwen2TextEncoderResidencyInput {
3058 on_gpu: loaded.text_encoder.on_gpu,
3059 is_quantized: loaded.text_encoder.is_quantized,
3060 is_metal: loaded.device.is_metal(),
3061 keep_te_ram: crate::device::keep_te_in_ram(),
3062 prompt_cache_miss: !both_cached,
3063 transformer_resident: loaded.transformer.is_some(),
3064 free_vram_bytes: free_after_encode,
3065 required_vram_bytes: required_for_residency,
3066 });
3067 match action {
3068 Qwen2TextEncoderPostEncodeAction::KeepGpu => {
3069 progress.info(&format!(
3070 "Keeping Qwen2.5 text encoder on GPU for hot prompt-cache misses ({} free >= {} reserve)",
3071 fmt_gb(free_after_encode),
3072 fmt_gb(required_for_residency)
3073 ));
3074 tracing::info!(
3075 free_vram_bytes = free_after_encode,
3076 required_vram_bytes = required_for_residency,
3077 is_quantized = loaded.text_encoder.is_quantized,
3078 "Qwen2.5 text encoder kept on GPU after cache miss"
3079 );
3080 }
3081 Qwen2TextEncoderPostEncodeAction::ParkCpu => {
3082 loaded.text_encoder.park_to_cpu()?;
3083 progress.info(&format!(
3084 "Parked Qwen2.5 text encoder to CPU host RAM before denoise ({} free < {} reserve)",
3085 fmt_gb(free_after_encode),
3086 fmt_gb(required_for_residency)
3087 ));
3088 tracing::info!("Qwen2.5 text encoder parked to CPU host RAM");
3089 }
3090 Qwen2TextEncoderPostEncodeAction::Drop => {
3091 loaded.text_encoder.drop_weights();
3092 progress.info(&format!(
3093 "Dropped Qwen2.5 text encoder before denoise ({} free < {} reserve or cache hit)",
3094 fmt_gb(free_after_encode),
3095 fmt_gb(required_for_residency)
3096 ));
3097 tracing::info!("Qwen2.5 text encoder dropped from GPU");
3098 }
3099 }
3100 }
3101
3102 let vae_downsample = 8;
3104 let latent_h = height / vae_downsample;
3105 let latent_w = width / vae_downsample;
3106 let is_img2img = req.source_image.is_some();
3107
3108 let (prepared_img2img_latents, inpaint_ctx) =
3110 if let Some(ref source_bytes) = req.source_image {
3111 let encoded = Self::encode_vae_with_fallback(
3112 source_bytes,
3113 req.width,
3114 req.height,
3115 &loaded.vae,
3116 &loaded.vae_device,
3117 &loaded.device,
3118 progress,
3119 || {
3120 Ok(QwenImageVae::load(
3121 &loaded.vae_path,
3122 &Device::Cpu,
3123 DType::F32,
3124 progress,
3125 )?)
3126 },
3127 )?;
3128 let encoded = encoded.to_device(&loaded.device)?.to_dtype(loaded.dtype)?;
3129 let start_sigma = QwenImageScheduler::new_img2img(
3130 req.steps as usize,
3131 image_seq_len(latent_h, latent_w, loaded.transformer_cfg.patch_size),
3132 req.strength,
3133 )
3134 .0
3135 .initial_sigma();
3136 let prepared = crate::img2img::prepare_flow_match_img2img(
3137 &encoded,
3138 seed,
3139 &[1, 16, latent_h, latent_w],
3140 start_sigma,
3141 req.mask_image.as_deref(),
3142 latent_h,
3143 latent_w,
3144 &loaded.device,
3145 loaded.dtype,
3146 )?;
3147
3148 (Some(prepared.initial_latents), prepared.inpaint_ctx)
3149 } else {
3150 (None, None)
3151 };
3152
3153 let image_seq_len = image_seq_len(latent_h, latent_w, loaded.transformer_cfg.patch_size);
3155 let (mut scheduler, num_steps) = if is_img2img {
3156 QwenImageScheduler::new_img2img(req.steps as usize, image_seq_len, req.strength)
3157 } else {
3158 let sched = QwenImageScheduler::new(req.steps as usize, image_seq_len);
3159 let n = sched.num_steps();
3160 (sched, n)
3161 };
3162
3163 let mut latents = if let Some(initial) = &prepared_img2img_latents {
3165 initial.clone()
3166 } else {
3167 let noise = crate::engine::seeded_randn(
3168 seed,
3169 &[1, 16, latent_h, latent_w],
3170 &loaded.device,
3171 loaded.dtype,
3172 )?;
3173 (noise * scheduler.initial_sigma())?
3174 };
3175
3176 let denoise_label = format!("Denoising ({} steps)", num_steps);
3178 progress.stage_start(&denoise_label);
3179 let denoise_start = Instant::now();
3180
3181 {
3182 let transformer = loaded
3183 .transformer
3184 .as_ref()
3185 .expect("transformer must be loaded for denoising");
3186
3187 let use_batched_cfg = use_cfg && transformer.supports_cfg_batching();
3188 if use_cfg && !use_batched_cfg {
3189 progress.info(
3190 "Low-memory quantized Qwen CUDA path detected — disabling CFG batching to reduce peak CUDA memory",
3191 );
3192 }
3193
3194 let (batched_hs, batched_mask) = if use_batched_cfg {
3197 let hs = Tensor::cat(&[&encoder_hidden_states, uncond_hs.as_ref().unwrap()], 0)?;
3198 let mask =
3199 Tensor::cat(&[&encoder_attention_mask, uncond_mask.as_ref().unwrap()], 0)?;
3200 (hs, mask)
3201 } else {
3202 (
3203 encoder_hidden_states.clone(),
3204 encoder_attention_mask.clone(),
3205 )
3206 };
3207
3208 for step in 0..num_steps {
3209 let step_start = Instant::now();
3210 let t = scheduler.current_timestep();
3211 let noise_pred = if use_cfg {
3212 let (cond_pred, uncond_pred) = if use_batched_cfg {
3213 let t_tensor = Tensor::from_vec(vec![t as f32; 2], (2,), &loaded.device)?
3214 .to_dtype(loaded.dtype)?;
3215 let batched_latents = Tensor::cat(&[&latents, &latents], 0)?;
3216 let batched_pred = transformer.forward(
3217 &batched_latents,
3218 &t_tensor,
3219 &batched_hs,
3220 &batched_mask,
3221 )?;
3222 (batched_pred.narrow(0, 0, 1)?, batched_pred.narrow(0, 1, 1)?)
3223 } else {
3224 let t_tensor = Tensor::from_vec(vec![t as f32], (1,), &loaded.device)?
3225 .to_dtype(loaded.dtype)?;
3226 (
3227 transformer.forward(
3228 &latents,
3229 &t_tensor,
3230 &encoder_hidden_states,
3231 &encoder_attention_mask,
3232 )?,
3233 transformer.forward(
3234 &latents,
3235 &t_tensor,
3236 uncond_hs.as_ref().unwrap(),
3237 uncond_mask.as_ref().unwrap(),
3238 )?,
3239 )
3240 };
3241 let cond_f32 = cond_pred.to_dtype(DType::F32)?;
3243 let uncond_f32 = uncond_pred.to_dtype(DType::F32)?;
3244 let comb = (&uncond_f32 + ((&cond_f32 - &uncond_f32)? * req.guidance)?)?;
3245 let cond_norm = cond_f32.sqr()?.sum_keepdim(1)?.sqrt()?;
3246 let comb_norm = comb.sqr()?.sum_keepdim(1)?.sqrt()?.clamp(1e-8, f64::MAX)?;
3247 let rescaled = comb.broadcast_mul(&(cond_norm / comb_norm)?)?;
3248 rescaled.to_dtype(loaded.dtype)?
3249 } else {
3250 let t_tensor = Tensor::from_vec(vec![t as f32], (1,), &loaded.device)?
3251 .to_dtype(loaded.dtype)?;
3252 transformer.forward(
3253 &latents,
3254 &t_tensor,
3255 &encoder_hidden_states,
3256 &encoder_attention_mask,
3257 )?
3258 };
3259 if step == 0 || step == num_steps / 2 || step == num_steps - 1 {
3260 Self::debug_tensor_stats(&format!("noise_pred[{step}]"), &noise_pred);
3261 Self::debug_tensor_stats(&format!("latents[{step}]"), &latents);
3262 }
3263 if step == 0 {
3264 Self::validate_qwen_tensor_boundary("noise_pred[0]", &noise_pred)?;
3265 }
3266 latents = scheduler.step(&noise_pred, &latents)?;
3267 if step == num_steps - 1 {
3268 Self::validate_qwen_tensor_boundary("latents_final", &latents)?;
3269 }
3270
3271 if let Some(ref ctx) = inpaint_ctx {
3273 latents = crate::img2img::apply_flow_match_inpaint(
3274 &latents,
3275 ctx,
3276 scheduler.sigmas[step + 1],
3277 )?;
3278 }
3279
3280 progress.emit(ProgressEvent::DenoiseStep {
3281 step: step + 1,
3282 total: num_steps,
3283 elapsed: step_start.elapsed(),
3284 });
3285 }
3286 }
3287
3288 progress.stage_done(&denoise_label, denoise_start.elapsed());
3289
3290 drop(encoder_hidden_states);
3292 drop(encoder_attention_mask);
3293 drop(uncond_hs);
3294 drop(uncond_mask);
3295
3296 progress.stage_start("VAE decode");
3298 let vae_start = Instant::now();
3299 let free_for_decode = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
3300 let prefer_tiled = Self::should_proactively_tile_vae_decode(
3301 req.width,
3302 req.height,
3303 loaded.vae_device.is_cuda(),
3304 free_for_decode,
3305 );
3306
3307 let keep_transformer_hot = Self::can_keep_transformer_hot_for_vae(loaded);
3309 let image = if keep_transformer_hot {
3310 match Self::decode_vae_gpu_only(
3311 &latents,
3312 &loaded.vae,
3313 &loaded.vae_device,
3314 &loaded.device,
3315 progress,
3316 prefer_tiled,
3317 ) {
3318 Ok(image) => {
3319 progress.info(
3320 "Kept quantized Qwen transformer resident across VAE decode for faster hot-path reuse",
3321 );
3322 image
3323 }
3324 Err(err) if Self::is_oom_error(&err) => {
3325 loaded.transformer = None;
3326 loaded.device.synchronize()?;
3327 progress.info(
3328 "Dropping Qwen-Image transformer after resident VAE decode OOM and retrying",
3329 );
3330 Self::decode_vae_with_fallback(
3331 &latents,
3332 &loaded.vae,
3333 &loaded.vae_device,
3334 &loaded.device,
3335 progress,
3336 prefer_tiled,
3337 || {
3338 QwenImageVae::load(&loaded.vae_path, &Device::Cpu, DType::F32, progress)
3339 .map_err(Into::into)
3340 },
3341 )?
3342 }
3343 Err(err) => return Err(err),
3344 }
3345 } else {
3346 loaded.transformer = None;
3347 loaded.device.synchronize()?;
3348 tracing::info!("Qwen-Image transformer dropped to free VRAM for VAE decode");
3349 Self::decode_vae_with_fallback(
3350 &latents,
3351 &loaded.vae,
3352 &loaded.vae_device,
3353 &loaded.device,
3354 progress,
3355 prefer_tiled,
3356 || {
3357 QwenImageVae::load(&loaded.vae_path, &Device::Cpu, DType::F32, progress)
3358 .map_err(Into::into)
3359 },
3360 )?
3361 };
3362 Self::validate_qwen_tensor_boundary("image_pre_postprocess", &image)?;
3363 Self::debug_tensor_stats("image_pre_postprocess", &image);
3364 let image = postprocess_image(&image)?;
3365 let post_stats = Self::validate_qwen_tensor_boundary("image_postprocess", &image)?;
3366 Self::debug_tensor_stats("image_postprocess", &image);
3367 let image = image.i(0)?;
3368 if Self::near_black_image_stats(post_stats) {
3369 progress.info(
3370 "Qwen diagnostic: decoded image is near-black after VAE postprocess; inspect MOLD_QWEN_DEBUG tensor stats to separate denoise math from VAE decode",
3371 );
3372 tracing::warn!(
3373 min = post_stats.min,
3374 max = post_stats.max,
3375 mean = post_stats.mean,
3376 "Qwen decoded image is near-black after VAE postprocess"
3377 );
3378 }
3379
3380 progress.stage_done("VAE decode", vae_start.elapsed());
3381
3382 let output_metadata = build_output_metadata(req, seed, None);
3384 let image_bytes = encode_image(
3385 &image,
3386 req.resolved_output_format(),
3387 req.width,
3388 req.height,
3389 output_metadata.as_ref(),
3390 )?;
3391
3392 let generation_time_ms = start.elapsed().as_millis() as u64;
3393 tracing::info!(generation_time_ms, seed, "Qwen-Image generation complete");
3394
3395 Ok(GenerateResponse {
3396 images: vec![ImageData {
3397 data: image_bytes,
3398 format: req.resolved_output_format(),
3399 width: req.width,
3400 height: req.height,
3401 index: 0,
3402 }],
3403 generation_time_ms,
3404 model: req.model.clone(),
3405 seed_used: seed,
3406 video: None,
3407 gpu: None,
3408 })
3409 }
3410}
3411
3412impl InferenceEngine for QwenImageEngine {
3413 fn generate(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
3414 self.pending_placement = req.placement.clone();
3415 self.pending_loras = effective_loras(req);
3416 let result = self.generate_inner(req);
3417 self.pending_placement = None;
3418 self.pending_loras.clear();
3419 result
3420 }
3421
3422 fn model_name(&self) -> &str {
3423 self.base.model_name()
3424 }
3425
3426 fn is_loaded(&self) -> bool {
3427 self.base.is_loaded()
3428 }
3429
3430 fn load(&mut self) -> Result<()> {
3431 QwenImageEngine::load(self)
3432 }
3433
3434 fn unload(&mut self) {
3435 self.base.unload();
3436 clear_cache(&self.prompt_cache);
3437 }
3438
3439 fn set_on_progress(&mut self, callback: ProgressCallback) {
3440 self.base.set_on_progress(callback);
3441 }
3442
3443 fn clear_on_progress(&mut self) {
3444 self.base.clear_on_progress();
3445 }
3446
3447 fn model_paths(&self) -> Option<&mold_core::ModelPaths> {
3448 Some(&self.base.paths)
3449 }
3450}
3451
3452#[cfg(test)]
3453mod tests {
3454 use super::*;
3455 use crate::engine::LoadStrategy;
3456 use crate::shared_pool::SharedPool;
3457 use candle_core::Shape;
3458 use mold_core::ModelPaths;
3459 use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
3460 use std::collections::HashMap;
3461 use std::fs;
3462 use std::path::{Path, PathBuf};
3463 use std::sync::{Arc, Mutex};
3464 use std::time::{SystemTime, UNIX_EPOCH};
3465 use tokenizers::models::bpe::BPE;
3466
/// Creates a fresh directory under the system temp dir and returns its path.
///
/// The directory name is `{prefix}-{pid}-{nanoseconds since the Unix epoch}`.
/// Panics if the clock is before the epoch or the directory cannot be created.
fn temp_test_dir(prefix: &str) -> PathBuf {
    let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
    let dir_name = format!(
        "{prefix}-{}-{}",
        std::process::id(),
        since_epoch.as_nanos()
    );

    let mut path = std::env::temp_dir();
    path.push(dir_name);
    fs::create_dir_all(&path).unwrap();
    path
}
3476
/// Writes a small placeholder file called `name` inside `dir` and returns its path.
///
/// Panics if the file cannot be written.
fn touch(dir: &Path, name: &str) -> PathBuf {
    let mut target = dir.to_path_buf();
    target.push(name);
    fs::write(target.as_path(), b"test").unwrap();
    target
}
3482
3483 fn png_with_dimensions(width: u32, height: u32) -> Vec<u8> {
3484 let img = image::RgbImage::from_fn(width, height, |_, _| image::Rgb([255, 0, 0]));
3485 let mut buf = std::io::Cursor::new(Vec::new());
3486 image::DynamicImage::ImageRgb8(img)
3487 .write_to(&mut buf, image::ImageFormat::Png)
3488 .unwrap();
3489 buf.into_inner()
3490 }
3491
3492 fn qwen_image_model_paths(
3493 transformer: PathBuf,
3494 transformer_shards: Vec<PathBuf>,
3495 vae: PathBuf,
3496 text_tokenizer: Option<PathBuf>,
3497 ) -> ModelPaths {
3498 ModelPaths {
3499 transformer,
3500 transformer_shards,
3501 vae,
3502 spatial_upscaler: None,
3503 temporal_upscaler: None,
3504 distilled_lora: None,
3505 t5_encoder: None,
3506 clip_encoder: None,
3507 t5_tokenizer: None,
3508 clip_tokenizer: None,
3509 clip_encoder_2: None,
3510 clip_tokenizer_2: None,
3511 text_encoder_files: vec![],
3512 text_tokenizer,
3513 decoder: None,
3514 }
3515 }
3516
3517 fn resolved_text_encoder(is_gguf: bool, auto_use_gpu: bool) -> ResolvedQwen2TextEncoder {
3518 ResolvedQwen2TextEncoder {
3519 paths: vec![],
3520 vision_paths: vec![],
3521 is_gguf,
3522 variant_label: if is_gguf {
3523 "q6".to_string()
3524 } else {
3525 "bf16".to_string()
3526 },
3527 size_bytes: 0,
3528 auto_use_gpu,
3529 }
3530 }
3531
3532 fn tensor_values_u8(t: &Tensor) -> Vec<u8> {
3533 t.flatten_all()
3534 .unwrap()
3535 .to_vec1::<u8>()
3536 .expect("u8 tensor values")
3537 }
3538
3539 fn tensor_values_f32(t: &Tensor) -> Vec<f32> {
3540 t.flatten_all()
3541 .unwrap()
3542 .to_vec1::<f32>()
3543 .expect("f32 tensor values")
3544 }
3545
/// `safetensors_is_fp8` accepts a path whose file name contains `fp8` and
/// rejects one that does not.
#[test]
fn safetensors_is_fp8_uses_filename_hint() {
    let fp8_path = Path::new("/tmp/qwen-image-fp8.safetensors");
    let plain_path = Path::new("/tmp/qwen-image.safetensors");

    assert!(safetensors_is_fp8(fp8_path));
    assert!(!safetensors_is_fp8(plain_path));
}
3555
/// `text_encoder_is_fp8` accepts a shard list whose file name contains `fp8`
/// and rejects one that does not.
#[test]
fn text_encoder_is_fp8_uses_filename_hint() {
    let fp8_shards = [PathBuf::from(
        "/tmp/qwen2-text-encoder-fp8-00001-of-00002.safetensors",
    )];
    let plain_shards = [PathBuf::from(
        "/tmp/qwen2-text-encoder-00001-of-00002.safetensors",
    )];

    assert!(text_encoder_is_fp8(&fp8_shards));
    assert!(!text_encoder_is_fp8(&plain_shards));
}
3565
/// Caching a `(1, 3, 2)` hidden-state tensor with `from_parts(.., 2)` and
/// restoring it yields the same values and the mask `[1, 1, 0]`.
#[test]
fn cached_prompt_conditioning_roundtrips_and_restores_mask() {
    let cpu = Device::Cpu;
    let source = Tensor::from_vec(
        vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
        Shape::from((1, 3, 2)),
        &cpu,
    )
    .unwrap();

    let cached = CachedPromptConditioning::from_parts(&source, 2).unwrap();
    let (hs_back, mask_back) = cached.restore(&cpu, DType::F32).unwrap();

    let expected_values = tensor_values_f32(&source);
    assert_eq!(tensor_values_f32(&hs_back), expected_values);
    assert_eq!(tensor_values_u8(&mask_back), vec![1, 1, 0]);
}
3585
/// Padding to the current sequence length leaves values and mask unchanged.
#[test]
fn pad_text_conditioning_keeps_original_when_target_matches() {
    let cpu = Device::Cpu;
    let source_hs =
        Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], Shape::from((1, 2, 2)), &cpu).unwrap();
    let source_mask = Tensor::from_vec(vec![1u8, 1], Shape::from((1, 2)), &cpu).unwrap();

    let (out_hs, out_mask) = pad_text_conditioning(&source_hs, &source_mask, 2).unwrap();

    let expected_values = tensor_values_f32(&source_hs);
    assert_eq!(tensor_values_f32(&out_hs), expected_values);
    assert_eq!(tensor_values_u8(&out_mask), vec![1, 1]);
}
3601
/// Padding a length-2 sequence to length 4 appends zeros to both the hidden
/// states and the mask.
#[test]
fn pad_text_conditioning_appends_zero_padding() {
    let cpu = Device::Cpu;
    let source_hs =
        Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], Shape::from((1, 2, 2)), &cpu).unwrap();
    let source_mask = Tensor::from_vec(vec![1u8, 0], Shape::from((1, 2)), &cpu).unwrap();

    let (out_hs, out_mask) = pad_text_conditioning(&source_hs, &source_mask, 4).unwrap();

    let out_dims = out_hs.dims3().unwrap();
    assert_eq!(out_dims, (1, 4, 2));

    let out_values = tensor_values_f32(&out_hs);
    assert_eq!(out_values, vec![1.0, 2.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0]);

    let out_mask_values = tensor_values_u8(&out_mask);
    assert_eq!(out_mask_values, vec![1, 0, 0, 0]);
}
3618
/// Asking for a target length shorter than the input is an error.
#[test]
fn pad_text_conditioning_rejects_shrinking() {
    let cpu = Device::Cpu;
    let source_hs =
        Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], Shape::from((1, 2, 2)), &cpu).unwrap();
    let source_mask = Tensor::from_vec(vec![1u8, 1], Shape::from((1, 2)), &cpu).unwrap();

    let failure = pad_text_conditioning(&source_hs, &source_mask, 1).unwrap_err();
    let rendered = failure.to_string();
    assert!(rendered.contains("cannot shrink text conditioning"));
}
3629
/// Aligning a length-3 conditional branch with a length-2 unconditional branch
/// zero-pads the shorter one to length 3 and leaves the longer one untouched.
#[test]
fn align_cfg_conditioning_pads_shorter_branch_to_match_longer_one() {
    let cpu = Device::Cpu;

    // Conditional branch: three positions, all unmasked.
    let long_hs = Tensor::from_vec(
        vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
        Shape::from((1, 3, 2)),
        &cpu,
    )
    .unwrap();
    let long_mask = Tensor::from_vec(vec![1u8, 1, 1], Shape::from((1, 3)), &cpu).unwrap();

    // Unconditional branch: two positions.
    let short_hs = Tensor::from_vec(
        vec![7.0f32, 8.0, 9.0, 10.0],
        Shape::from((1, 2, 2)),
        &cpu,
    )
    .unwrap();
    let short_mask = Tensor::from_vec(vec![1u8, 0], Shape::from((1, 2)), &cpu).unwrap();

    let ((cond_out, cond_mask_out), (uncond_out, uncond_mask_out)) =
        align_cfg_conditioning(&long_hs, &long_mask, &short_hs, &short_mask).unwrap();

    assert_eq!(cond_out.dims3().unwrap(), (1, 3, 2));
    assert_eq!(uncond_out.dims3().unwrap(), (1, 3, 2));
    assert_eq!(tensor_values_u8(&cond_mask_out), vec![1, 1, 1]);
    assert_eq!(tensor_values_u8(&uncond_mask_out), vec![1, 0, 0]);

    let uncond_values = tensor_values_f32(&uncond_out);
    assert_eq!(uncond_values, vec![7.0, 8.0, 9.0, 10.0, 0.0, 0.0]);
}
3660
/// An engine whose transformer path ends in `.gguf` reports itself as quantized.
#[test]
fn qwen_image_detects_gguf_transformer() {
    // `qwen_image_model_paths` fills every remaining field with `None` / empty,
    // which matches the hand-written literal this test used before.
    let paths = qwen_image_model_paths(
        PathBuf::from("/tmp/qwen-image-Q4_K_S.gguf"),
        vec![],
        PathBuf::from("/tmp/vae.safetensors"),
        Some(PathBuf::from("/tmp/tokenizer.json")),
    );

    let engine = QwenImageEngine::new(
        "qwen-image:q4".to_string(),
        paths,
        LoadStrategy::Sequential,
        0,
        false,
        None,
    );

    assert!(engine.detect_is_quantized());
}
3690
/// In `Auto` mode with flags `(false, true)` and a GGUF encoder that allows
/// GPU use, the plan selects the GPU without CPU staging.
// NOTE(review): the two bool flags are presumably (is_cuda, is_metal) — confirm
// against `qwen2_text_encoder_plan_for_mode`.
#[test]
fn qwen_image_text_encoder_uses_gpu_on_metal() {
    let encoder = resolved_text_encoder(true, true);

    let plan = QwenImageEngine::qwen2_text_encoder_plan_for_mode(
        Qwen2TextEncoderMode::Auto,
        false,
        true,
        &encoder,
    );

    assert!(plan.use_gpu);
    assert!(!plan.use_cpu_staging);
}
3702
/// In `Auto` mode with flags `(true, false)` and a non-GGUF encoder that allows
/// GPU use, the plan selects the GPU without CPU staging.
// NOTE(review): the two bool flags are presumably (is_cuda, is_metal) — confirm
// against `qwen2_text_encoder_plan_for_mode`.
#[test]
fn qwen_image_text_encoder_uses_gpu_on_cuda_with_headroom() {
    let encoder = resolved_text_encoder(false, true);

    let plan = QwenImageEngine::qwen2_text_encoder_plan_for_mode(
        Qwen2TextEncoderMode::Auto,
        true,
        false,
        &encoder,
    );

    assert!(plan.use_gpu);
    assert!(!plan.use_cpu_staging);
}
3714
/// In `Auto` mode with flags `(true, false)` and an encoder that does not allow
/// GPU use, the plan stays off the GPU and does not stage through the CPU.
// NOTE(review): the two bool flags are presumably (is_cuda, is_metal) — confirm
// against `qwen2_text_encoder_plan_for_mode`.
#[test]
fn qwen_image_text_encoder_uses_cpu_on_cuda_without_headroom() {
    let encoder = resolved_text_encoder(false, false);

    let plan = QwenImageEngine::qwen2_text_encoder_plan_for_mode(
        Qwen2TextEncoderMode::Auto,
        true,
        false,
        &encoder,
    );

    assert!(!plan.use_gpu);
    assert!(!plan.use_cpu_staging);
}
3726
/// `text_encoder_load_dtype(false, BF16)` resolves to `F32`.
#[test]
fn qwen_image_cpu_safetensors_text_encoder_stays_f32() {
    let load_dtype = QwenImageEngine::text_encoder_load_dtype(false, DType::BF16);
    assert_eq!(load_dtype, DType::F32);
}
3734
/// `text_encoder_load_dtype(false, BF16)` resolves to `F32`.
// NOTE(review): this call is identical to the one in
// `qwen_image_cpu_safetensors_text_encoder_stays_f32`; nothing here is
// GGUF-specific. Confirm whether the helper is meant to take a GGUF flag or
// whether one of the two tests is redundant.
#[test]
fn qwen_image_cpu_gguf_text_encoder_stays_f32() {
    let load_dtype = QwenImageEngine::text_encoder_load_dtype(false, DType::BF16);
    assert_eq!(load_dtype, DType::F32);
}
3742
/// With the explicit `Gpu` mode and flags `(false, true)`, the plan uses the
/// GPU and does not stage through the CPU.
// NOTE(review): the two bool flags are presumably (is_cuda, is_metal) — confirm
// against `qwen2_text_encoder_plan_for_mode`.
#[test]
fn qwen_image_text_encoder_gpu_override_disables_metal_staging() {
    let encoder = resolved_text_encoder(true, true);

    let plan = QwenImageEngine::qwen2_text_encoder_plan_for_mode(
        Qwen2TextEncoderMode::Gpu,
        false,
        true,
        &encoder,
    );

    assert!(plan.use_gpu);
    assert!(!plan.use_cpu_staging);
}
3754
/// With free VRAM one byte above the q6 threshold, `auto` picks the q6 GGUF
/// variant and allows GPU use.
#[test]
fn qwen_image_auto_prefers_q6_on_metal_with_headroom() {
    let q6_variant = mold_core::manifest::find_qwen2_vl_variant("q6").unwrap();
    let free_vram = qwen2_vram_threshold(q6_variant.size_bytes) + 1;

    let choice = QwenImageEngine::choose_text_encoder_source(
        Some("auto"),
        false,
        true,
        free_vram,
        16_600_000_000,
        Qwen2TextEncoderUsage::Resident,
    )
    .unwrap();

    assert!(choice.is_gguf);
    assert_eq!(choice.variant_label, "q6");
    assert!(choice.auto_use_gpu);
}
3771
/// With free VRAM exactly at the q4 threshold (and below the q6 one), `auto`
/// picks the q4 GGUF variant and allows GPU use.
#[test]
fn qwen_image_auto_falls_back_to_q4_on_metal_when_q6_does_not_fit() {
    let q4_variant = mold_core::manifest::find_qwen2_vl_variant("q4").unwrap();
    let q6_variant = mold_core::manifest::find_qwen2_vl_variant("q6").unwrap();

    let free_vram = qwen2_vram_threshold(q4_variant.size_bytes);
    // Sanity check on the fixture: q6 must genuinely not fit.
    let q6_threshold = qwen2_vram_threshold(q6_variant.size_bytes);
    assert!(free_vram < q6_threshold);

    let choice = QwenImageEngine::choose_text_encoder_source(
        Some("auto"),
        false,
        true,
        free_vram,
        0,
        Qwen2TextEncoderUsage::Resident,
    )
    .unwrap();

    assert!(choice.is_gguf);
    assert_eq!(choice.variant_label, "q4");
    assert!(choice.auto_use_gpu);
}
3792
/// With free VRAM one byte above `QWEN2_FP16_VRAM_THRESHOLD`, `auto` keeps the
/// non-GGUF bf16 encoder and allows GPU use.
#[test]
fn qwen_image_auto_keeps_bf16_default_on_cuda() {
    let free_vram = QWEN2_FP16_VRAM_THRESHOLD + 1;

    let choice = QwenImageEngine::choose_text_encoder_source(
        Some("auto"),
        true,
        false,
        free_vram,
        16_600_000_000,
        Qwen2TextEncoderUsage::Resident,
    )
    .unwrap();

    assert!(!choice.is_gguf);
    assert_eq!(choice.variant_label, "bf16");
    assert!(choice.auto_use_gpu);
}
3808
/// With free VRAM one byte below `QWEN2_FP16_VRAM_THRESHOLD` in resident mode,
/// `auto` picks the q4 GGUF variant and allows GPU use.
#[test]
fn qwen_image_auto_prefers_quantized_gpu_on_cuda_for_resident_mode_when_it_fits() {
    let free_vram = QWEN2_FP16_VRAM_THRESHOLD - 1;

    let choice = QwenImageEngine::choose_text_encoder_source(
        Some("auto"),
        true,
        false,
        free_vram,
        16_600_000_000,
        Qwen2TextEncoderUsage::Resident,
    )
    .unwrap();

    assert!(choice.is_gguf);
    assert_eq!(choice.variant_label, "q4");
    assert!(choice.auto_use_gpu);
}
3824
/// With a single byte of free VRAM in resident mode, `auto` still picks the q4
/// GGUF variant but does not allow GPU use.
#[test]
fn qwen_image_auto_uses_quantized_cpu_fallback_on_cuda_for_resident_mode() {
    let usage = Qwen2TextEncoderUsage::Resident;

    let choice = QwenImageEngine::choose_text_encoder_source(
        Some("auto"),
        true,
        false,
        1,
        16_600_000_000,
        usage,
    )
    .unwrap();

    assert!(choice.is_gguf);
    assert_eq!(choice.variant_label, "q4");
    assert!(!choice.auto_use_gpu);
}
3840
/// With free VRAM one byte below `QWEN2_FP16_VRAM_THRESHOLD` in sequential
/// mode, `auto` picks the q4 GGUF variant and allows GPU use.
#[test]
fn qwen_image_auto_prefers_quantized_gpu_on_cuda_for_sequential_mode_when_it_fits() {
    let free_vram = QWEN2_FP16_VRAM_THRESHOLD - 1;

    let choice = QwenImageEngine::choose_text_encoder_source(
        Some("auto"),
        true,
        false,
        free_vram,
        16_600_000_000,
        Qwen2TextEncoderUsage::Sequential,
    )
    .unwrap();

    assert!(choice.is_gguf);
    assert_eq!(choice.variant_label, "q4");
    assert!(choice.auto_use_gpu);
}
3856
/// With a single byte of free VRAM in sequential mode, `auto` still picks the
/// q4 GGUF variant but does not allow GPU use.
#[test]
fn qwen_image_auto_uses_quantized_cpu_fallback_on_cuda_for_sequential_mode() {
    let usage = Qwen2TextEncoderUsage::Sequential;

    let choice = QwenImageEngine::choose_text_encoder_source(
        Some("auto"),
        true,
        false,
        1,
        16_600_000_000,
        usage,
    )
    .unwrap();

    assert!(choice.is_gguf);
    assert_eq!(choice.variant_label, "q4");
    assert!(!choice.auto_use_gpu);
}
3872
/// An explicit `q6` preference with a single byte of free VRAM keeps the q6
/// GGUF variant but does not allow GPU use.
#[test]
fn qwen_image_explicit_q6_respects_cpu_fallback_on_cuda() {
    let usage = Qwen2TextEncoderUsage::Resident;

    let choice =
        QwenImageEngine::choose_text_encoder_source(Some("q6"), true, false, 1, 0, usage)
            .unwrap();

    assert!(choice.is_gguf);
    assert_eq!(choice.variant_label, "q6");
    assert!(!choice.auto_use_gpu);
}
3888
/// For a `qwen-image-edit` engine with one text-encoder file on disk, every
/// preference (`auto`, `q4`, `bf16`) resolves with vision sidecar paths:
/// non-empty for `auto`, exactly one for `q4` and `bf16`.
#[test]
fn qwen_image_edit_accepts_quantized_text_with_bf16_vision_sidecar() {
    let dir = temp_test_dir("qwen-image-edit-text-encoder");
    let transformer = touch(&dir, "qwen-image-edit.gguf");
    let vae = touch(&dir, "vae.safetensors");
    let tokenizer = touch(&dir, "tokenizer.json");
    let mut paths = qwen_image_model_paths(transformer, vec![], vae, Some(tokenizer));
    paths.text_encoder_files = vec![touch(&dir, "text-encoder-00001-of-00004.safetensors")];
    let engine = QwenImageEngine::new(
        "qwen-image-edit-2511:q4".to_string(),
        paths,
        LoadStrategy::Sequential,
        0,
        false,
        None,
    );

    // `auto`: only the presence of a vision sidecar is checked.
    let resolved = engine
        .resolve_text_encoder_source_with_preference(
            &Device::Cpu,
            0,
            Qwen2TextEncoderUsage::Sequential,
            Some("auto"),
        )
        .unwrap();
    assert!(!resolved.vision_paths.is_empty());

    // Explicit `q4`: quantized text encoder plus exactly one vision path.
    let resolved = engine
        .resolve_text_encoder_source_with_preference(
            &Device::Cpu,
            0,
            Qwen2TextEncoderUsage::Sequential,
            Some("q4"),
        )
        .unwrap();
    assert!(resolved.is_gguf);
    assert_eq!(resolved.variant_label, "q4");
    assert_eq!(resolved.vision_paths.len(), 1);

    // Explicit `bf16`: non-quantized text encoder plus exactly one vision path.
    let resolved = engine
        .resolve_text_encoder_source_with_preference(
            &Device::Cpu,
            0,
            Qwen2TextEncoderUsage::Sequential,
            Some("bf16"),
        )
        .unwrap();
    assert!(!resolved.is_gguf);
    assert_eq!(resolved.variant_label, "bf16");
    assert_eq!(resolved.vision_paths.len(), 1);

    // Best-effort cleanup so repeated runs do not pile up directories in the
    // system temp dir; a failure to remove is deliberately ignored.
    fs::remove_dir_all(dir).ok();
}
3940
/// The edit prompt for three images contains the system prompt, one numbered
/// `Picture N:` placeholder per image, and ends with the assistant turn marker.
#[test]
fn qwen_image_edit_prompt_numbers_each_picture_placeholder() {
    let prompt = QwenImageEngine::qwen_image_edit_prompt("swap materials", 3);

    assert!(prompt.contains(QWEN_IMAGE_EDIT_SYSTEM_PROMPT));
    for index in 1..=3 {
        let placeholder =
            format!("Picture {index}: <|vision_start|><|image_pad|><|vision_end|>");
        assert!(prompt.contains(placeholder.as_str()));
    }
    assert!(prompt.ends_with("<|im_start|>assistant\n"));
}
3950
/// A 1600x900 input fitted to `QWEN_IMAGE_EDIT_VAE_AREA` becomes 1360x768,
/// with both sides multiples of 16.
#[test]
fn qwen_image_edit_image_dims_fit_target_area_with_16px_alignment() {
    let png = png_with_dimensions(1600, 900);

    let dims =
        QwenImageEngine::qwen_image_edit_image_dims(&png, QWEN_IMAGE_EDIT_VAE_AREA).unwrap();
    assert_eq!(dims, (1360, 768));

    let (fitted_width, fitted_height) = dims;
    assert_eq!(fitted_width % 16, 0);
    assert_eq!(fitted_height % 16, 0);
}
3960
/// Packing `(1, 16, 4, 6)` latents gives shape `(1, 6, 64)`, and unpacking
/// restores the original element order.
#[test]
fn pack_and_unpack_latents_roundtrip() {
    let element_count = 16 * 4 * 6;
    let values: Vec<f32> = (0..element_count).map(|i| i as f32).collect();
    let latents = Tensor::from_vec(values.clone(), (1, 16, 4, 6), &Device::Cpu).unwrap();

    let packed = QwenImageEngine::pack_latents_4d(&latents).unwrap();
    assert_eq!(packed.dims3().unwrap(), (1, 6, 64));

    let unpacked = QwenImageEngine::unpack_latents_packed(&packed, 4, 6).unwrap();
    let roundtripped = unpacked.flatten_all().unwrap().to_vec1::<f32>().unwrap();
    assert_eq!(roundtripped, values);
}
3974
/// The CFG headroom is `QWEN_GGUF_NATIVE_CFG_HEADROOM` at 1328x1328 and
/// `QWEN_GGUF_MIN_CFG_HEADROOM` at 512x512.
#[test]
fn quantized_cuda_cfg_headroom_scales_with_resolution() {
    let at_native = QwenImageEngine::quantized_cuda_cfg_headroom(1328, 1328);
    assert_eq!(at_native, QWEN_GGUF_NATIVE_CFG_HEADROOM);

    let at_reduced = QwenImageEngine::quantized_cuda_cfg_headroom(512, 512);
    assert_eq!(at_reduced, QWEN_GGUF_MIN_CFG_HEADROOM);
}
3982
/// At 1328x1328 with a 12.3 GB transformer and 24.6 GB as the second argument,
/// CFG is split.
#[test]
fn qwen_quantized_native_resolution_uses_split_cfg_on_24gb_cuda() {
    let split = QwenImageEngine::should_split_cfg_quantized_cuda(
        12_300_000_000,
        24_600_000_000,
        1328,
        1328,
    );
    assert!(split);
}
3992
/// At 512x512 with the same transformer size and memory figure, CFG stays
/// batched.
#[test]
fn qwen_quantized_reduced_resolution_keeps_batched_cfg_when_it_fits() {
    let split = QwenImageEngine::should_split_cfg_quantized_cuda(
        12_300_000_000,
        24_600_000_000,
        512,
        512,
    );
    assert!(!split);
}
4002
/// When free VRAM equals transformer size plus the native headroom exactly,
/// CFG is not split.
#[test]
fn qwen_quantized_cfg_split_boundary_does_not_split_when_estimate_exactly_fits() {
    let transformer_size = 12_300_000_000;
    let native_headroom = QwenImageEngine::quantized_cuda_cfg_headroom(1328, 1328);
    let exact_fit = transformer_size + native_headroom;

    let split =
        QwenImageEngine::should_split_cfg_quantized_cuda(transformer_size, exact_fit, 1328, 1328);
    assert!(!split);
}
4015
/// A free-VRAM figure of zero results in split CFG at native resolution.
#[test]
fn qwen_quantized_unknown_vram_biases_to_split_cfg() {
    let split =
        QwenImageEngine::should_split_cfg_quantized_cuda(12_300_000_000, 0, 1328, 1328);
    assert!(split);
}
4025
/// `is_oom_error` recognises the `cudaErrorMemoryAllocation` message.
#[test]
fn qwen_is_oom_error_matches_cuda_memory_allocation_string() {
    let message = "cudaErrorMemoryAllocation";
    assert!(QwenImageEngine::is_oom_error(&message));
}
4030
/// `tensor_stats` counts NaN and both infinities separately and, as the
/// expected values show, leaves them out of min, max and mean.
#[test]
fn qwen_debug_stats_counts_nan_and_inf() {
    let samples = vec![0.0f32, 1.0, f32::NAN, f32::INFINITY, f32::NEG_INFINITY];
    let tensor = Tensor::from_vec(samples, Shape::from((5,)), &Device::Cpu).unwrap();

    let summary = QwenImageEngine::tensor_stats(&tensor).unwrap();

    // Counters.
    assert_eq!(summary.total, 5);
    assert_eq!(summary.nan_count, 1);
    assert_eq!(summary.pos_inf_count, 1);
    assert_eq!(summary.neg_inf_count, 1);
    // Aggregates over the two finite samples.
    assert_eq!(summary.min, 0.0);
    assert_eq!(summary.max, 1.0);
    assert_eq!(summary.mean, 0.5);
}
4050
/// Stats with max 0.01 and mean 0.004 are classified as near-black.
#[test]
fn qwen_debug_stats_detects_near_black_postprocessed_image() {
    let dark = QwenTensorStats {
        total: 1024,
        min: 0.0,
        max: 0.01,
        mean: 0.004,
        nan_count: 0,
        pos_inf_count: 0,
        neg_inf_count: 0,
    };

    let flagged = QwenImageEngine::near_black_image_stats(dark);
    assert!(flagged);
}
4065
/// Stats with max 0.75 and mean 0.18 are not classified as near-black.
#[test]
fn qwen_debug_stats_does_not_flag_non_black_image() {
    let bright = QwenTensorStats {
        total: 1024,
        min: 0.0,
        max: 0.75,
        mean: 0.18,
        nan_count: 0,
        pos_inf_count: 0,
        neg_inf_count: 0,
    };

    let flagged = QwenImageEngine::near_black_image_stats(bright);
    assert!(!flagged);
}
4080
/// The formatted stats message includes the NaN ratio and both infinity counts.
#[test]
fn qwen_debug_stats_formats_progress_message() {
    let sample_stats = QwenTensorStats {
        total: 10,
        min: 0.0,
        max: 1.0,
        mean: 0.5,
        nan_count: 2,
        pos_inf_count: 1,
        neg_inf_count: 1,
    };

    let rendered = QwenImageEngine::format_tensor_stats("sample", sample_stats);

    for fragment in ["NaN=2/10", "+Inf=1", "-Inf=1"] {
        assert!(rendered.contains(fragment));
    }
}
4099
/// When the primary closure succeeds, its value is returned, the fallback is
/// never invoked, and no info message is emitted.
#[test]
fn qwen_oom_fallback_returns_primary_success_without_running_fallback() {
    use std::sync::atomic::{AtomicBool, Ordering};

    // Collect every `Info` progress message.
    let info_log = Arc::new(Mutex::new(Vec::<String>::new()));
    let info_sink = Arc::clone(&info_log);
    let mut progress = ProgressReporter::default();
    progress.set_callback(Box::new(move |event| {
        if let ProgressEvent::Info { message } = event {
            info_sink.lock().unwrap().push(message);
        }
    }));

    let fallback_ran = Arc::new(AtomicBool::new(false));
    let fallback_flag = Arc::clone(&fallback_ran);
    let value = QwenImageEngine::with_cuda_oom_cpu_fallback(
        || Ok(7usize),
        || {
            fallback_flag.store(true, Ordering::SeqCst);
            Ok(9usize)
        },
        true,
        &Device::Cpu,
        &progress,
        "retrying",
        |_| true,
    )
    .unwrap();

    assert_eq!(value, 7);
    assert!(!fallback_ran.load(Ordering::SeqCst));
    assert!(info_log.lock().unwrap().is_empty());
}
4131
/// When the primary closure fails with an OOM error, the fallback value is
/// returned and the retry message is reported exactly once.
#[test]
fn qwen_oom_fallback_retries_when_primary_ooms_on_cuda() {
    // Collect every `Info` progress message.
    let info_log = Arc::new(Mutex::new(Vec::<String>::new()));
    let info_sink = Arc::clone(&info_log);
    let mut progress = ProgressReporter::default();
    progress.set_callback(Box::new(move |event| {
        if let ProgressEvent::Info { message } = event {
            info_sink.lock().unwrap().push(message);
        }
    }));

    let value = QwenImageEngine::with_cuda_oom_cpu_fallback(
        || Err(anyhow::anyhow!("cudaErrorMemoryAllocation")),
        || Ok(11usize),
        true,
        &Device::Cpu,
        &progress,
        "retrying",
        QwenImageEngine::is_oom_error,
    )
    .unwrap();

    assert_eq!(value, 11);
    let reported = info_log.lock().unwrap();
    assert_eq!(reported.as_slice(), ["retrying"]);
}
4157
/// A primary failure that is not an OOM error is returned as-is instead of
/// being replaced by the fallback result.
#[test]
fn qwen_oom_fallback_does_not_retry_non_oom_errors() {
    let progress = ProgressReporter::default();

    let outcome = QwenImageEngine::with_cuda_oom_cpu_fallback(
        || Err(anyhow::anyhow!("not an oom")),
        || Ok(11usize),
        true,
        &Device::Cpu,
        &progress,
        "retrying",
        QwenImageEngine::is_oom_error,
    );

    let failure = outcome.unwrap_err();
    assert!(failure.to_string().contains("not an oom"));
}
4174
/// When the primary decode succeeds, neither the tiled nor the CPU closure is
/// invoked.
#[test]
fn qwen_tiled_fallback_returns_primary_success_without_retrying() {
    use std::sync::atomic::{AtomicBool, Ordering};

    let progress = ProgressReporter::default();
    let tiled_ran = Arc::new(AtomicBool::new(false));
    let cpu_ran = Arc::new(AtomicBool::new(false));
    let tiled_flag = Arc::clone(&tiled_ran);
    let cpu_flag = Arc::clone(&cpu_ran);

    let value = QwenImageEngine::with_cuda_tiled_then_cpu_fallback(
        || Ok(5usize),
        || {
            tiled_flag.store(true, Ordering::SeqCst);
            Ok(7usize)
        },
        || {
            cpu_flag.store(true, Ordering::SeqCst);
            Ok(9usize)
        },
        true,
        false,
        &Device::Cpu,
        &progress,
        "tiled",
        "cpu",
        |_| true,
    )
    .unwrap();

    assert_eq!(value, 5);
    assert!(!tiled_ran.load(Ordering::SeqCst));
    assert!(!cpu_ran.load(Ordering::SeqCst));
}
4207
/// After a primary OOM, a successful tiled decode is returned, the CPU closure
/// is never invoked, and only the tiled message is reported.
#[test]
fn qwen_tiled_fallback_uses_tiled_result_before_cpu() {
    use std::sync::atomic::{AtomicBool, Ordering};

    // Collect every `Info` progress message.
    let info_log = Arc::new(Mutex::new(Vec::<String>::new()));
    let info_sink = Arc::clone(&info_log);
    let mut progress = ProgressReporter::default();
    progress.set_callback(Box::new(move |event| {
        if let ProgressEvent::Info { message } = event {
            info_sink.lock().unwrap().push(message);
        }
    }));

    let cpu_ran = Arc::new(AtomicBool::new(false));
    let cpu_flag = Arc::clone(&cpu_ran);
    let value = QwenImageEngine::with_cuda_tiled_then_cpu_fallback(
        || Err(anyhow::anyhow!("out of memory")),
        || Ok(13usize),
        || {
            cpu_flag.store(true, Ordering::SeqCst);
            Ok(17usize)
        },
        true,
        false,
        &Device::Cpu,
        &progress,
        "tiled",
        "cpu",
        QwenImageEngine::is_oom_error,
    )
    .unwrap();

    assert_eq!(value, 13);
    assert!(!cpu_ran.load(Ordering::SeqCst));
    let reported = info_log.lock().unwrap();
    assert_eq!(reported.as_slice(), ["tiled"]);
}
4242
/// When both the primary and the tiled decode hit OOM, the CPU result is
/// returned and the tiled and CPU messages are reported in that order.
#[test]
fn qwen_tiled_fallback_uses_cpu_after_tiled_oom() {
    // Collect every `Info` progress message.
    let info_log = Arc::new(Mutex::new(Vec::<String>::new()));
    let info_sink = Arc::clone(&info_log);
    let mut progress = ProgressReporter::default();
    progress.set_callback(Box::new(move |event| {
        if let ProgressEvent::Info { message } = event {
            info_sink.lock().unwrap().push(message);
        }
    }));

    let value = QwenImageEngine::with_cuda_tiled_then_cpu_fallback(
        || Err(anyhow::anyhow!("OUT_OF_MEMORY")),
        || Err(anyhow::anyhow!("OUT_OF_MEMORY")),
        || Ok(19usize),
        true,
        false,
        &Device::Cpu,
        &progress,
        "tiled",
        "cpu",
        QwenImageEngine::is_oom_error,
    )
    .unwrap();

    assert_eq!(value, 19);
    let reported = info_log.lock().unwrap();
    assert_eq!(reported.as_slice(), ["tiled", "cpu"]);
}
4271
/// A tiled decode failure that is not an OOM error is returned directly rather
/// than falling through to the CPU closure.
#[test]
fn qwen_tiled_fallback_propagates_non_oom_tiled_error() {
    let progress = ProgressReporter::default();

    let outcome = QwenImageEngine::with_cuda_tiled_then_cpu_fallback(
        || Err(anyhow::anyhow!("out of memory")),
        || Err(anyhow::anyhow!("bad tiled decode")),
        || Ok(19usize),
        true,
        false,
        &Device::Cpu,
        &progress,
        "tiled",
        "cpu",
        QwenImageEngine::is_oom_error,
    );

    let failure = outcome.unwrap_err();
    assert!(failure.to_string().contains("bad tiled decode"));
}
4291
/// Proactive tiling is chosen only for the 1328x1328 CUDA case with 6 GB free;
/// a smaller image, a non-CUDA device, or 16 GB free each turn it off.
#[test]
fn qwen_proactive_tiled_policy_selects_native_cuda_under_pressure() {
    // (width, height, VAE on CUDA, free VRAM bytes, expected decision)
    let cases = [
        (1328, 1328, true, 6_000_000_000, true),
        (512, 512, true, 6_000_000_000, false),
        (1328, 1328, false, 6_000_000_000, false),
        (1328, 1328, true, 16_000_000_000, false),
    ];

    for (width, height, on_cuda, free_vram, expected) in cases {
        let decision = QwenImageEngine::should_proactively_tile_vae_decode(
            width, height, on_cuda, free_vram,
        );
        assert_eq!(decision, expected);
    }
}
4319
/// With the fifth argument set to `true`, the primary closure is never run,
/// the tiled result is returned, and the proactive-selection message is the
/// only one reported.
// NOTE(review): the fifth argument is presumably the "prefer tiled" flag —
// confirm against `with_cuda_tiled_then_cpu_fallback`.
#[test]
fn qwen_proactive_tiled_decode_skips_primary_full_decode() {
    use std::sync::atomic::{AtomicBool, Ordering};

    // Collect every `Info` progress message.
    let info_log = Arc::new(Mutex::new(Vec::<String>::new()));
    let info_sink = Arc::clone(&info_log);
    let mut progress = ProgressReporter::default();
    progress.set_callback(Box::new(move |event| {
        if let ProgressEvent::Info { message } = event {
            info_sink.lock().unwrap().push(message);
        }
    }));

    let primary_ran = Arc::new(AtomicBool::new(false));
    let primary_flag = Arc::clone(&primary_ran);
    let value = QwenImageEngine::with_cuda_tiled_then_cpu_fallback(
        || {
            primary_flag.store(true, Ordering::SeqCst);
            Ok(3usize)
        },
        || Ok(7usize),
        || Ok(9usize),
        true,
        true,
        &Device::Cpu,
        &progress,
        "tiled after oom",
        "cpu",
        QwenImageEngine::is_oom_error,
    )
    .unwrap();

    assert_eq!(value, 7);
    assert!(!primary_ran.load(Ordering::SeqCst));
    let reported = info_log.lock().unwrap();
    assert_eq!(
        reported.as_slice(),
        ["Selecting tiled GPU VAE decode proactively"]
    );
}
4357
/// A quantized GPU encoder, after a prompt-cache miss with the transformer
/// resident and more free VRAM than required, stays on the GPU.
#[test]
fn qwen_hot_text_encoder_keeps_gpu_after_cache_miss_with_headroom() {
    let input = Qwen2TextEncoderResidencyInput {
        on_gpu: true,
        is_quantized: true,
        is_metal: false,
        keep_te_ram: false,
        prompt_cache_miss: true,
        transformer_resident: true,
        required_vram_bytes: 8_000_000_000,
        free_vram_bytes: 10_000_000_000,
    };

    let decision = QwenImageEngine::qwen2_text_encoder_post_encode_action(input);

    assert_eq!(decision, Qwen2TextEncoderPostEncodeAction::KeepGpu);
}
4375
/// With `prompt_cache_miss` false, the encoder is dropped even though free
/// VRAM exceeds the requirement.
#[test]
fn qwen_hot_text_encoder_drops_after_cache_hit_even_with_headroom() {
    let input = Qwen2TextEncoderResidencyInput {
        on_gpu: true,
        is_quantized: true,
        is_metal: false,
        keep_te_ram: false,
        prompt_cache_miss: false,
        transformer_resident: true,
        required_vram_bytes: 8_000_000_000,
        free_vram_bytes: 10_000_000_000,
    };

    let decision = QwenImageEngine::qwen2_text_encoder_post_encode_action(input);

    assert_eq!(decision, Qwen2TextEncoderPostEncodeAction::Drop);
}
4393
/// With free VRAM one byte short of the requirement, the quantized encoder is
/// dropped.
#[test]
fn qwen_hot_text_encoder_drops_under_transformer_pressure() {
    let input = Qwen2TextEncoderResidencyInput {
        on_gpu: true,
        is_quantized: true,
        is_metal: false,
        keep_te_ram: false,
        prompt_cache_miss: true,
        transformer_resident: true,
        required_vram_bytes: 8_000_000_000,
        free_vram_bytes: 7_999_999_999,
    };

    let decision = QwenImageEngine::qwen2_text_encoder_post_encode_action(input);

    assert_eq!(decision, Qwen2TextEncoderPostEncodeAction::Drop);
}
4411
/// A non-quantized encoder with `keep_te_ram` set and insufficient free VRAM
/// is parked on the CPU.
#[test]
fn qwen_hot_text_encoder_parks_bf16_when_keep_ram_enabled() {
    let input = Qwen2TextEncoderResidencyInput {
        on_gpu: true,
        is_quantized: false,
        is_metal: false,
        keep_te_ram: true,
        prompt_cache_miss: true,
        transformer_resident: true,
        required_vram_bytes: 8_000_000_000,
        free_vram_bytes: 7_999_999_999,
    };

    let decision = QwenImageEngine::qwen2_text_encoder_post_encode_action(input);

    assert_eq!(decision, Qwen2TextEncoderPostEncodeAction::ParkCpu);
}
4429
/// A quantized encoder is dropped rather than parked, even with `keep_te_ram`
/// set and insufficient free VRAM.
#[test]
fn qwen_hot_text_encoder_never_parks_quantized() {
    let input = Qwen2TextEncoderResidencyInput {
        on_gpu: true,
        is_quantized: true,
        is_metal: false,
        keep_te_ram: true,
        prompt_cache_miss: true,
        transformer_resident: true,
        required_vram_bytes: 8_000_000_000,
        free_vram_bytes: 7_999_999_999,
    };

    let decision = QwenImageEngine::qwen2_text_encoder_post_encode_action(input);

    assert_eq!(decision, Qwen2TextEncoderPostEncodeAction::Drop);
}
4447
/// With `transformer_resident` false, the encoder is dropped even though free
/// VRAM exceeds the requirement.
#[test]
fn qwen_hot_text_encoder_drops_when_transformer_not_resident() {
    let input = Qwen2TextEncoderResidencyInput {
        on_gpu: true,
        is_quantized: true,
        is_metal: false,
        keep_te_ram: false,
        prompt_cache_miss: true,
        transformer_resident: false,
        required_vram_bytes: 8_000_000_000,
        free_vram_bytes: 10_000_000_000,
    };

    let decision = QwenImageEngine::qwen2_text_encoder_post_encode_action(input);

    assert_eq!(decision, Qwen2TextEncoderPostEncodeAction::Drop);
}
4465
4466 #[test]
4467 fn qwen_transformer_hot_vae_eligibility_requires_quantized_cuda_components() {
4468 assert!(QwenImageEngine::qwen_transformer_can_stay_hot_for_vae(
4469 true, true, true
4470 ));
4471 assert!(!QwenImageEngine::qwen_transformer_can_stay_hot_for_vae(
4472 false, true, true
4473 ));
4474 assert!(!QwenImageEngine::qwen_transformer_can_stay_hot_for_vae(
4475 true, false, true
4476 ));
4477 assert!(!QwenImageEngine::qwen_transformer_can_stay_hot_for_vae(
4478 true, true, false
4479 ));
4480 }
4481
4482 #[test]
4483 fn qwen_transformer_paths_prefer_shards_when_present() {
4484 let dir = temp_test_dir("mold-qwen-shards");
4485 let shard_a = touch(&dir, "transformer-00001-of-00002.safetensors");
4486 let shard_b = touch(&dir, "transformer-00002-of-00002.safetensors");
4487 let engine = QwenImageEngine::new(
4488 "qwen-image:q4".to_string(),
4489 qwen_image_model_paths(
4490 dir.join("transformer.safetensors"),
4491 vec![shard_a.clone(), shard_b.clone()],
4492 dir.join("vae.safetensors"),
4493 Some(dir.join("tokenizer.json")),
4494 ),
4495 LoadStrategy::Sequential,
4496 0,
4497 false,
4498 None,
4499 );
4500
4501 assert_eq!(engine.transformer_paths(), vec![shard_a, shard_b]);
4502
4503 fs::remove_dir_all(dir).ok();
4504 }
4505
4506 #[test]
4507 fn qwen_validate_paths_accepts_existing_files() {
4508 let dir = temp_test_dir("mold-qwen-validate-ok");
4509 let shard_a = touch(&dir, "transformer-00001-of-00002.safetensors");
4510 let shard_b = touch(&dir, "transformer-00002-of-00002.safetensors");
4511 let vae = touch(&dir, "vae.safetensors");
4512 let tokenizer = touch(&dir, "tokenizer.json");
4513 let gguf = touch(&dir, "transformer.gguf");
4514
4515 let sharded = QwenImageEngine::new(
4516 "qwen-image:bf16".to_string(),
4517 qwen_image_model_paths(
4518 dir.join("transformer.safetensors"),
4519 vec![shard_a, shard_b],
4520 vae.clone(),
4521 Some(tokenizer.clone()),
4522 ),
4523 LoadStrategy::Sequential,
4524 0,
4525 false,
4526 None,
4527 );
4528 assert_eq!(sharded.validate_paths().unwrap(), tokenizer);
4529 assert!(!sharded.detect_is_quantized());
4530
4531 let quantized = QwenImageEngine::new(
4532 "qwen-image:q4".to_string(),
4533 qwen_image_model_paths(gguf, vec![], vae, Some(dir.join("tokenizer.json"))),
4534 LoadStrategy::Sequential,
4535 0,
4536 false,
4537 None,
4538 );
4539 assert!(quantized.detect_is_quantized());
4540
4541 fs::remove_dir_all(dir).ok();
4542 }
4543
4544 #[test]
4545 fn qwen_validate_paths_requires_text_tokenizer() {
4546 let dir = temp_test_dir("mold-qwen-validate-missing");
4547 let engine = QwenImageEngine::new(
4548 "qwen-image:q4".to_string(),
4549 qwen_image_model_paths(
4550 dir.join("transformer.gguf"),
4551 vec![],
4552 dir.join("vae.safetensors"),
4553 None,
4554 ),
4555 LoadStrategy::Sequential,
4556 0,
4557 false,
4558 None,
4559 );
4560
4561 let err = engine.validate_paths().unwrap_err();
4562 assert!(err.to_string().contains("text tokenizer path required"));
4563
4564 fs::remove_dir_all(dir).ok();
4565 }
4566
4567 #[test]
4568 fn qwen_image_loads_text_tokenizer_through_shared_pool() {
4569 let dir = temp_test_dir("mold-qwen-tokenizer-pool");
4570 let tokenizer_path = dir.join("tokenizer.json");
4571 tokenizers::Tokenizer::new(BPE::default())
4572 .save(&tokenizer_path, false)
4573 .unwrap();
4574
4575 let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
4576 let pooled = shared_pool
4577 .lock()
4578 .unwrap()
4579 .load_tokenizer(&tokenizer_path)
4580 .unwrap();
4581
4582 let engine = QwenImageEngine::new(
4583 "qwen-image:q4".to_string(),
4584 qwen_image_model_paths(
4585 dir.join("transformer.gguf"),
4586 vec![],
4587 dir.join("vae.safetensors"),
4588 Some(tokenizer_path.clone()),
4589 ),
4590 LoadStrategy::Sequential,
4591 0,
4592 false,
4593 Some(shared_pool),
4594 );
4595
4596 let loaded = engine.load_text_tokenizer(&tokenizer_path).unwrap();
4597
4598 assert!(Arc::ptr_eq(&pooled, &loaded));
4599 fs::remove_dir_all(dir).ok();
4600 }
4601
4602 #[test]
4603 fn qwen_image_loads_vae_tensors_through_shared_pool() {
4604 let dir = temp_test_dir("mold-qwen-vae-pool");
4605 let vae_path = dir.join("vae.safetensors");
4606 let weight = 1.0f32.to_le_bytes();
4607 let mut tensors = HashMap::new();
4608 tensors.insert(
4609 "encoder.conv_in.weight".to_string(),
4610 TensorView::new(SafeDtype::F32, vec![1], &weight).unwrap(),
4611 );
4612 serialize_to_file(&tensors, &None, &vae_path).unwrap();
4613
4614 let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
4615 let pooled = shared_pool
4616 .lock()
4617 .unwrap()
4618 .load_safetensors_cpu_tensors(std::slice::from_ref(&vae_path))
4619 .unwrap()
4620 .unwrap();
4621
4622 let engine = QwenImageEngine::new(
4623 "qwen-image:q4".to_string(),
4624 qwen_image_model_paths(
4625 dir.join("transformer.gguf"),
4626 vec![],
4627 vae_path.clone(),
4628 Some(dir.join("tokenizer.json")),
4629 ),
4630 LoadStrategy::Sequential,
4631 0,
4632 false,
4633 Some(shared_pool),
4634 );
4635
4636 let loaded = engine.load_vae_cpu_tensors().unwrap().unwrap();
4637
4638 assert!(Arc::ptr_eq(&pooled, &loaded));
4639 fs::remove_dir_all(dir).ok();
4640 }
4641
4642 #[test]
4643 fn qwen_img2img_uses_minus_one_to_one_source_normalization() {
4644 assert_eq!(
4645 QwenImageEngine::img2img_source_normalize_range(),
4646 img_utils::NormalizeRange::MinusOneToOne
4647 );
4648 }
4649}