Skip to main content

mold_inference/qwen_image/
pipeline.rs

1//! Qwen-Image-2512 inference engine.
2//!
3//! Pipeline: Qwen2.5-VL text encoder -> QwenImageTransformer2DModel -> QwenImage VAE
4//!
5//! Architecture follows Z-Image closely (both from Alibaba/Tongyi):
6//! - Dual-stream transformer with joint attention and 3D RoPE
7//! - Flow-matching Euler discrete scheduler with dynamic shifting
8//! - Drop-and-reload for text encoder to manage VRAM
9//! - Both Eager and Sequential loading modes
10//!
11//! Key differences from Z-Image:
12//! - 60 identical dual-stream blocks (no noise_refiner/context_refiner)
13//! - Qwen2.5-VL text encoder (hidden_size=3584) instead of Qwen3 (2560)
14//! - Custom VAE with per-channel latent normalization
15//! - Official diffusers-style exponential time shift with dynamic per-image stretch
16
17use anyhow::{bail, Result};
18use candle_core::{DType, Device, IndexOp, Tensor, D};
19use candle_transformers::models::z_image::postprocess_image;
20use candle_transformers::quantized_var_builder;
21use mold_core::{fit_to_target_area, GenerateRequest, GenerateResponse, ImageData, ModelPaths};
22use std::collections::HashMap;
23use std::path::Path;
24use std::sync::{Arc, Mutex};
25use std::time::Instant;
26use tokenizers::Tokenizer;
27
28use super::quantized_transformer::QuantizedQwenImageTransformer2DModel;
29use super::sampling::{image_seq_len, QwenImageScheduler};
30use super::transformer::{QwenImageConfig, QwenImageTransformer2DModel};
31use super::vae::QwenImageVae;
32use crate::cache::{
33    clear_cache, prompt_text_key, CachedTensor, LruCache, DEFAULT_PROMPT_CACHE_CAPACITY,
34};
35use crate::device::{
36    effective_device_ref, fits_in_memory, fmt_gb, free_vram_bytes, memory_status_string,
37    preflight_memory_check, qwen2_vram_threshold, should_use_gpu, usable_free_vram_bytes,
38};
39use crate::encoders;
40use crate::engine::{rand_seed, InferenceEngine, LoadStrategy};
41use crate::engine_base::EngineBase;
42use crate::image::{build_output_metadata, encode_image};
43use crate::img_utils;
44use crate::progress::{ProgressCallback, ProgressEvent, ProgressReporter};
45use crate::upscaler::tiling::{upscale_with_tiling, TilingConfig};
46
/// Minimum free VRAM (bytes) required to place Qwen-Image VAE on GPU.
/// The VAE weights are ~300MB; decode workspace at 1024x1024 needs ~1-2GB.
const VAE_DECODE_VRAM_THRESHOLD: u64 = 2_500_000_000;
/// Negative prompt used when the request has none.
///
/// A single space rather than an empty string keeps the unconditional CFG
/// path explicit after Qwen prompt templating and token windowing.
const QWEN_EMPTY_NEGATIVE_PROMPT: &str = " ";
/// Native output width of Qwen-Image in pixels.
const QWEN_NATIVE_WIDTH: usize = 1328;
/// Native output height of Qwen-Image in pixels.
const QWEN_NATIVE_HEIGHT: usize = 1328;
/// VRAM headroom (bytes) for GGUF transformers running CFG at native resolution.
/// NOTE(review): consumed outside this view — confirm the exact budget it covers.
const QWEN_GGUF_NATIVE_CFG_HEADROOM: u64 = 14_000_000_000;
/// Minimum VRAM headroom (bytes) for GGUF transformers running CFG.
/// NOTE(review): consumed outside this view — confirm the exact budget it covers.
const QWEN_GGUF_MIN_CFG_HEADROOM: u64 = 3_000_000_000;
/// Tile sizes tried in order by the tiled VAE decode fallback, largest first.
/// Tiling is applied to the latent tensor, so these are latent-space sizes.
const QWEN_VAE_TILE_SIZES: [u32; 3] = [64, 32, 16];
/// Pixel area (1024 x 1024).
/// Presumably the target area used when resizing Qwen-Image-Edit inputs for
/// VAE encoding — TODO confirm against the call site.
const QWEN_IMAGE_EDIT_VAE_AREA: u32 = 1024 * 1024;
/// System turn inserted by `QwenImageEngine::qwen_image_edit_prompt` for edit-family prompts.
const QWEN_IMAGE_EDIT_SYSTEM_PROMPT: &str = "Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.";

/// Minimum free VRAM for BF16 Qwen2.5-VL 7B text encoder on GPU.
/// ~14GB model + 2GB headroom.
const QWEN2_FP16_VRAM_THRESHOLD: u64 = 16_000_000_000;
/// Extra residual VRAM required before keeping Qwen2.5 on GPU after a prompt
/// cache miss. The denoise/VAE reserves cover known workspaces; this absorbs
/// allocator fragmentation and backend scratch buffers.
const QWEN2_HOT_TE_RESIDENCY_HEADROOM: u64 = 1_000_000_000;
68
/// Placement override for the Qwen2.5-VL text encoder, read from the
/// `MOLD_QWEN2_TEXT_ENCODER_MODE` environment variable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Qwen2TextEncoderMode {
    /// No override (variable unset, empty, or unrecognised).
    Auto,
    /// `gpu`
    Gpu,
    /// `cpu-stage` or `cpu_stage` (presumably CPU staging of the weights — confirm).
    CpuStage,
    /// `cpu`
    Cpu,
}

impl Qwen2TextEncoderMode {
    /// Parse `MOLD_QWEN2_TEXT_ENCODER_MODE`.
    ///
    /// Matching is ASCII case-insensitive and ignores surrounding whitespace.
    /// Unset, non-UTF-8, or unrecognised values yield [`Qwen2TextEncoderMode::Auto`].
    fn from_env() -> Self {
        let raw = std::env::var("MOLD_QWEN2_TEXT_ENCODER_MODE").unwrap_or_default();
        // Trim so values written from files or heredocs (e.g. "gpu\n") still match
        // instead of silently falling back to Auto.
        match raw.trim().to_ascii_lowercase().as_str() {
            "gpu" => Self::Gpu,
            "cpu-stage" | "cpu_stage" => Self::CpuStage,
            "cpu" => Self::Cpu,
            _ => Self::Auto,
        }
    }
}
92
/// Placement decision for the Qwen2.5-VL text encoder.
/// NOTE(review): constructed outside this view — field semantics below are
/// inferred from names; confirm against the planner.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Qwen2TextEncoderPlan {
    /// Presumably: run the encoder on the GPU.
    use_gpu: bool,
    /// Presumably: stage the encoder weights through CPU memory.
    use_cpu_staging: bool,
}
98
/// Which Qwen2.5-VL text-encoder variant was chosen by
/// `QwenImageEngine::choose_text_encoder_source`, and whether it would go on the GPU.
#[derive(Debug, Clone)]
struct ResolvedQwen2TextEncoder {
    /// Language-model weight files. `choose_text_encoder_source` leaves this
    /// empty; presumably filled in once the files are resolved — TODO confirm.
    paths: Vec<std::path::PathBuf>,
    /// Vision-tower weight files (also left empty by `choose_text_encoder_source`).
    vision_paths: Vec<std::path::PathBuf>,
    /// `true` for a GGUF variant (q8..q2), `false` for the `bf16` weights.
    is_gguf: bool,
    /// Variant tag, e.g. `"bf16"`, `"q6"` or `"q4"`.
    variant_label: String,
    /// Size of the chosen variant's weights in bytes.
    size_bytes: u64,
    /// Whether automatic placement would put this variant on the GPU.
    auto_use_gpu: bool,
}
108
/// How the text encoder is expected to be used during a generation.
/// NOTE(review): `choose_text_encoder_source` takes this as `_usage` and does not
/// read it in the visible code; variant meanings are inferred from names — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Qwen2TextEncoderUsage {
    /// Presumably: encoder loaded for encoding, then released (sequential loading).
    Sequential,
    /// Presumably: encoder kept loaded alongside the other components.
    Resident,
}
114
/// What to do with the Qwen2.5 text encoder after prompt encoding, as chosen by
/// `QwenImageEngine::qwen2_text_encoder_post_encode_action`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Qwen2TextEncoderPostEncodeAction {
    /// Keep the encoder resident on the GPU (non-Metal only, with enough free VRAM).
    KeepGpu,
    /// Presumably move the encoder to CPU RAM; only chosen for non-quantized,
    /// non-Metal encoders when `keep_te_ram` is set.
    ParkCpu,
    /// Release the encoder.
    Drop,
}
121
/// Inputs to `QwenImageEngine::qwen2_text_encoder_post_encode_action`.
#[derive(Debug, Clone, Copy)]
struct Qwen2TextEncoderResidencyInput {
    /// Whether the encoder currently lives on the GPU; if not, the action is always `Drop`.
    on_gpu: bool,
    /// Quantized encoders are never parked on CPU.
    is_quantized: bool,
    /// Metal devices never keep or park the encoder.
    is_metal: bool,
    /// Caller asked to keep the text encoder in RAM (enables `ParkCpu`).
    keep_te_ram: bool,
    /// The prompt was not found in the prompt cache.
    prompt_cache_miss: bool,
    /// The transformer is currently loaded.
    transformer_resident: bool,
    /// Free VRAM measured by the caller, in bytes.
    free_vram_bytes: u64,
    /// VRAM (bytes) that must remain free to keep the encoder on GPU.
    required_vram_bytes: u64,
}
133
/// Summary statistics of a tensor.
/// NOTE(review): presumably produced by the `debug_tensor_stats` diagnostics
/// helper (defined outside this view); whether min/max/mean skip non-finite
/// values is not visible here — confirm.
#[derive(Debug, Clone, Copy)]
struct QwenTensorStats {
    min: f32,
    max: f32,
    mean: f32,
    /// Count of NaN elements (per the field name).
    nan_count: u64,
    /// Count of `+inf` elements (per the field name).
    pos_inf_count: u64,
    /// Count of `-inf` elements (per the field name).
    neg_inf_count: u64,
    /// Total number of elements inspected.
    total: usize,
}
144
145/// Check if a Qwen-Image safetensors checkpoint stores weights in FP8 (F8_E4M3).
146/// Uses filename pattern first, then dtype probing as fallback.
147fn safetensors_is_fp8(path: &Path) -> bool {
148    // Filename-based detection
149    if path.to_str().map(|s| s.contains("fp8")).unwrap_or(false) {
150        return true;
151    }
152    // Dtype probing — try both ComfyUI and diffusers key names
153    let Ok(tensors) = (unsafe { candle_core::safetensors::MmapedSafetensors::multi(&[path]) })
154    else {
155        return false;
156    };
157    for key in ["x_embedder.weight", "img_in.weight"] {
158        if let Ok(t) = tensors.load(key, &Device::Cpu) {
159            return t.dtype() == DType::F8E4M3;
160        }
161    }
162    false
163}
164
165/// Check if text encoder safetensors contain FP8 weights.
166/// Uses filename pattern first (reliable for known ComfyUI FP8 models),
167/// then falls back to dtype probing.
168fn text_encoder_is_fp8(paths: &[std::path::PathBuf]) -> bool {
169    // Filename-based detection (ComfyUI FP8 models have "fp8" in name)
170    if paths
171        .iter()
172        .any(|p| p.to_str().map(|s| s.contains("fp8")).unwrap_or(false))
173    {
174        return true;
175    }
176    // Dtype probing fallback — try common key names
177    let Some(first) = paths.first() else {
178        return false;
179    };
180    let Ok(tensors) = (unsafe { candle_core::safetensors::MmapedSafetensors::multi(&[first]) })
181    else {
182        return false;
183    };
184    for key in [
185        "model.embed_tokens.weight",
186        "model.layers.0.self_attn.q_proj.weight",
187    ] {
188        if let Ok(t) = tensors.load(key, &Device::Cpu) {
189            return t.dtype() == DType::F8E4M3;
190        }
191    }
192    false
193}
194
/// Loaded Qwen-Image model components, ready for inference.
struct LoadedQwenImage {
    /// Transformer wrapped in Option for drop-and-reload pattern.
    transformer: Option<QwenImageTransformer>,
    /// Qwen2.5-VL text encoder.
    text_encoder: encoders::qwen2_text::Qwen2TextEncoder,
    /// VAE instance living on `vae_device`.
    vae: QwenImageVae,
    /// Path the VAE was loaded from (presumably kept so a CPU copy can be
    /// reloaded for the OOM fallbacks — confirm).
    vae_path: std::path::PathBuf,
    /// Transformer configuration (2512 or edit-2511 variant).
    transformer_cfg: QwenImageConfig,
    /// GPU device for transformer + denoising
    device: Device,
    /// Device where the VAE lives (may be CPU if VRAM is tight)
    vae_device: Device,
    /// Compute dtype for the model (presumably for the transformer/activations — confirm).
    dtype: DType,
}
209
// NOTE(review): lint silenced because variants embed whole models by value;
// presumably acceptable since only one instance is held per loaded model — confirm.
#[allow(clippy::large_enum_variant)]
/// Qwen-Image transformer backends behind one dispatch type.
enum QwenImageTransformer {
    /// Dense (non-quantized) transformer; named for BF16 weights.
    BF16(QwenImageTransformer2DModel),
    /// Quantized transformer; may opt out of CFG batching.
    Quantized(QuantizedQwenImageTransformer2DModel),
    /// Transformer whose weights are offloaded (see `super::offload`).
    Offloaded(super::offload::OffloadedQwenImageTransformer),
}
216
217#[derive(Clone)]
218struct CachedPromptConditioning {
219    hidden_states: CachedTensor,
220    valid_len: usize,
221}
222
223impl CachedPromptConditioning {
224    fn from_parts(hidden_states: &Tensor, valid_len: usize) -> Result<Self> {
225        Ok(Self {
226            hidden_states: CachedTensor::from_tensor(hidden_states)?,
227            valid_len,
228        })
229    }
230
231    fn restore(&self, device: &Device, dtype: DType) -> Result<(Tensor, Tensor)> {
232        let hidden_states = self.hidden_states.restore(device, dtype)?;
233        let mut mask = vec![0u8; hidden_states.dim(1)?];
234        for value in &mut mask[..self.valid_len] {
235            *value = 1;
236        }
237        let attention_mask = Tensor::from_vec(mask, (1, hidden_states.dim(1)?), device)?;
238        Ok((hidden_states, attention_mask))
239    }
240}
241
242fn pad_text_conditioning(
243    hidden_states: &Tensor,
244    attention_mask: &Tensor,
245    target_len: usize,
246) -> Result<(Tensor, Tensor)> {
247    let seq_len = hidden_states.dim(1)?;
248    if seq_len == target_len {
249        return Ok((hidden_states.clone(), attention_mask.clone()));
250    }
251    if seq_len > target_len {
252        bail!("cannot shrink text conditioning from {seq_len} to {target_len}");
253    }
254
255    let hidden_dim = hidden_states.dim(2)?;
256    let pad_len = target_len - seq_len;
257    let pad_hs = Tensor::zeros(
258        (hidden_states.dim(0)?, pad_len, hidden_dim),
259        hidden_states.dtype(),
260        hidden_states.device(),
261    )?;
262    let pad_mask = Tensor::zeros(
263        (attention_mask.dim(0)?, pad_len),
264        attention_mask.dtype(),
265        attention_mask.device(),
266    )?;
267
268    Ok((
269        Tensor::cat(&[hidden_states, &pad_hs], 1)?,
270        Tensor::cat(&[attention_mask, &pad_mask], 1)?,
271    ))
272}
273
274fn align_cfg_conditioning(
275    cond_hs: &Tensor,
276    cond_mask: &Tensor,
277    uncond_hs: &Tensor,
278    uncond_mask: &Tensor,
279) -> Result<((Tensor, Tensor), (Tensor, Tensor))> {
280    let target_len = cond_hs.dim(1)?.max(uncond_hs.dim(1)?);
281    let cond = pad_text_conditioning(cond_hs, cond_mask, target_len)?;
282    let uncond = pad_text_conditioning(uncond_hs, uncond_mask, target_len)?;
283    Ok((cond, uncond))
284}
285
286impl QwenImageTransformer {
287    fn supports_cfg_batching(&self) -> bool {
288        match self {
289            Self::Quantized(model) => model.supports_cfg_batching(),
290            _ => true,
291        }
292    }
293
294    fn forward(
295        &self,
296        latents: &Tensor,
297        t: &Tensor,
298        encoder_hidden_states: &Tensor,
299        encoder_attention_mask: &Tensor,
300    ) -> Result<Tensor> {
301        match self {
302            Self::BF16(model) => {
303                Ok(model.forward(latents, t, encoder_hidden_states, encoder_attention_mask)?)
304            }
305            Self::Quantized(model) => {
306                Ok(model.forward(latents, t, encoder_hidden_states, encoder_attention_mask)?)
307            }
308            Self::Offloaded(model) => {
309                model.forward(latents, t, encoder_hidden_states, encoder_attention_mask)
310            }
311        }
312    }
313
314    fn forward_packed(
315        &self,
316        packed_latents: &Tensor,
317        t: &Tensor,
318        encoder_hidden_states: &Tensor,
319        encoder_attention_mask: &Tensor,
320        img_shapes: &[(usize, usize, usize)],
321    ) -> Result<Tensor> {
322        match self {
323            Self::BF16(model) => Ok(model.forward_packed(
324                packed_latents,
325                t,
326                encoder_hidden_states,
327                encoder_attention_mask,
328                img_shapes,
329            )?),
330            Self::Quantized(model) => Ok(model.forward_packed(
331                packed_latents,
332                t,
333                encoder_hidden_states,
334                encoder_attention_mask,
335                img_shapes,
336            )?),
337            Self::Offloaded(model) => model.forward_packed(
338                packed_latents,
339                t,
340                encoder_hidden_states,
341                encoder_attention_mask,
342                img_shapes,
343            ),
344        }
345    }
346}
347
/// Qwen-Image-2512 inference engine.
///
/// Also serves the `qwen-image-edit*` family; see `is_edit_family`.
pub struct QwenImageEngine {
    /// Shared engine state, including the loaded components and the model name.
    base: EngineBase<LoadedQwenImage>,
    /// LRU cache of encoded prompts (presumably keyed via `prompt_text_key` — confirm).
    prompt_cache: Mutex<LruCache<String, CachedPromptConditioning>>,
    /// Presumably selects the offloaded transformer backend — confirm at load time.
    offload: bool,
    /// Per-request placement override.
    pending_placement: Option<mold_core::types::DevicePlacement>,
    /// Per-request LoRA stack. Captured at the start of `generate()`,
    /// cleared on exit. The transformer-load path consults this when
    /// constructing the `VarBuilder` so the LoRA-merged weights land
    /// before any forward pass runs.
    pending_loras: Vec<mold_core::LoraWeight>,
    /// Fingerprint of the LoRA stack currently baked into the loaded
    /// transformer. Eager-mode generates compare against this to decide
    /// whether to rebuild — an unchanged stack reuses the previously
    /// merged weights. Currently always recomputed at load time
    /// (same correctness-first stance as the sibling flux2 / sd3 / sdxl
    /// / z-image early ports); the fingerprint API is in place for the
    /// rebuild-elision follow-up.
    #[allow(dead_code)]
    active_lora_fingerprint: Vec<QwenImageLoraFingerprint>,
    /// Optional pool shared with other engines (semantics defined in `crate::shared_pool`).
    shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
}
371
372/// Order-sensitive fingerprint of a single LoRA adapter (path-hash + scale).
373#[derive(Clone, PartialEq, Eq, Debug)]
374#[allow(dead_code)]
375struct QwenImageLoraFingerprint {
376    path_hash: u64,
377    scale_bits: u64,
378}
379
380impl QwenImageLoraFingerprint {
381    #[allow(dead_code)]
382    fn from_lora(lora: &mold_core::LoraWeight) -> Self {
383        Self {
384            path_hash: super::lora::lora_path_hash(&lora.path),
385            scale_bits: lora.scale.to_bits(),
386        }
387    }
388}
389
390#[allow(dead_code)]
391fn fingerprint_stack(loras: &[mold_core::LoraWeight]) -> Vec<QwenImageLoraFingerprint> {
392    loras
393        .iter()
394        .map(QwenImageLoraFingerprint::from_lora)
395        .collect()
396}
397
398/// Resolve the effective LoRA list for a request. Mirrors `flux::pipeline::
399/// effective_loras` — accepts both `lora` (singular, legacy) and `loras`
400/// (plural, current). Entries with a near-zero scale are dropped.
401fn effective_loras(req: &mold_core::GenerateRequest) -> Vec<mold_core::LoraWeight> {
402    /// Match the FLUX threshold so the user-facing semantics are
403    /// identical across families.
404    const ZERO_SCALE_EPS: f64 = 1e-8;
405
406    let raw: Vec<mold_core::LoraWeight> = if let Some(plural) = &req.loras {
407        if !plural.is_empty() {
408            plural.clone()
409        } else {
410            req.lora.iter().cloned().collect()
411        }
412    } else {
413        req.lora.iter().cloned().collect()
414    };
415
416    raw.into_iter()
417        .filter(|w| {
418            let keep = w.scale.abs() > ZERO_SCALE_EPS;
419            if !keep {
420                tracing::debug!(
421                    path = w.path.as_str(),
422                    scale = w.scale,
423                    "dropping zero-scale LoRA from effective Qwen-Image stack"
424                );
425            }
426            keep
427        })
428        .collect()
429}
430
431impl QwenImageEngine {
432    fn is_edit_family(&self) -> bool {
433        self.base.model_name.starts_with("qwen-image-edit")
434    }
435
436    fn should_preload_text_encoder(&self) -> bool {
437        !self.is_edit_family()
438    }
439
440    fn text_encoder_load_dtype(use_gpu: bool, gpu_dtype: DType) -> DType {
441        if use_gpu {
442            gpu_dtype
443        } else {
444            // Candle CPU matmul does not support BF16 for the Qwen2.5 encoder path.
445            // Keep CPU language/vision encoding in F32 and use quantized GGUF when
446            // lower host residency is needed.
447            DType::F32
448        }
449    }
450
451    fn transformer_config(&self) -> QwenImageConfig {
452        if self.is_edit_family() {
453            QwenImageConfig::qwen_image_edit_2511()
454        } else {
455            QwenImageConfig::qwen_image_2512()
456        }
457    }
458
459    fn qwen_image_edit_prompt(prompt: &str, image_count: usize) -> String {
460        let picture_prefix = (0..image_count)
461            .map(|idx| {
462                format!(
463                    "Picture {}: <|vision_start|><|image_pad|><|vision_end|>",
464                    idx + 1
465                )
466            })
467            .collect::<String>();
468        format!(
469            "<|im_start|>system\n{QWEN_IMAGE_EDIT_SYSTEM_PROMPT}<|im_end|>\n<|im_start|>user\n{picture_prefix}{prompt}<|im_end|>\n<|im_start|>assistant\n"
470        )
471    }
472
473    fn qwen_image_edit_image_dims(image: &[u8], target_area: u32) -> Result<(u32, u32)> {
474        let img = image::load_from_memory(image)?;
475        Ok(fit_to_target_area(
476            img.width().max(1),
477            img.height().max(1),
478            target_area,
479            16,
480        ))
481    }
482
483    fn pack_latents_4d(latents: &Tensor) -> Result<Tensor> {
484        let (batch, channels, height, width) = latents.dims4()?;
485        let height_blocks = height / 2;
486        let width_blocks = width / 2;
487        latents
488            .reshape((batch, channels, height_blocks, 2, width_blocks, 2))?
489            .permute((0, 2, 4, 1, 3, 5))?
490            .reshape((batch, height_blocks * width_blocks, channels * 4))
491            .map_err(Into::into)
492    }
493
494    fn unpack_latents_packed(latents: &Tensor, latent_h: usize, latent_w: usize) -> Result<Tensor> {
495        let batch = latents.dim(0)?;
496        latents
497            .reshape((batch, latent_h / 2, latent_w / 2, 16, 2, 2))?
498            .permute((0, 3, 1, 4, 2, 5))?
499            .reshape((batch, 16, latent_h, latent_w))
500            .map_err(Into::into)
501    }
502
503    fn img2img_source_normalize_range() -> img_utils::NormalizeRange {
504        img_utils::NormalizeRange::MinusOneToOne
505    }
506
507    fn is_oom_error(err: &impl std::fmt::Display) -> bool {
508        // TODO: Replace this with typed backend inspection if candle exposes
509        // one. Today the fallback ladder has to key off the backend error text.
510        let msg = err.to_string();
511        msg.contains("OUT_OF_MEMORY")
512            || msg.contains("out of memory")
513            || msg.contains("cudaErrorMemoryAllocation")
514    }
515
516    fn with_cuda_oom_cpu_fallback<T, FPrimary, FFallback, FOom>(
517        primary: FPrimary,
518        fallback: FFallback,
519        is_cuda: bool,
520        sync_device: &Device,
521        progress: &ProgressReporter,
522        oom_message: &str,
523        is_oom: FOom,
524    ) -> Result<T>
525    where
526        FPrimary: FnOnce() -> Result<T>,
527        FFallback: FnOnce() -> Result<T>,
528        FOom: Fn(&anyhow::Error) -> bool,
529    {
530        match primary() {
531            Ok(value) => Ok(value),
532            Err(err) if is_cuda && is_oom(&err) => {
533                progress.info(oom_message);
534                sync_device.synchronize()?;
535                fallback()
536            }
537            Err(err) => Err(err),
538        }
539    }
540
541    #[allow(clippy::too_many_arguments)]
542    fn with_cuda_tiled_then_cpu_fallback<T, FPrimary, FTiled, FCpu, FOom>(
543        primary: FPrimary,
544        tiled: FTiled,
545        cpu_fallback: FCpu,
546        is_cuda: bool,
547        prefer_tiled: bool,
548        sync_device: &Device,
549        progress: &ProgressReporter,
550        tiled_message: &str,
551        cpu_message: &str,
552        is_oom: FOom,
553    ) -> Result<T>
554    where
555        FPrimary: FnOnce() -> Result<T>,
556        FTiled: FnOnce() -> Result<T>,
557        FCpu: FnOnce() -> Result<T>,
558        FOom: Fn(&anyhow::Error) -> bool,
559    {
560        if is_cuda && prefer_tiled {
561            progress.info("Selecting tiled GPU VAE decode proactively");
562            match tiled() {
563                Ok(value) => return Ok(value),
564                Err(tile_err) if is_oom(&tile_err) => {
565                    progress.info(cpu_message);
566                    sync_device.synchronize()?;
567                    return cpu_fallback();
568                }
569                Err(tile_err) => return Err(tile_err),
570            }
571        }
572
573        match primary() {
574            Ok(value) => Ok(value),
575            Err(err) if is_cuda && is_oom(&err) => {
576                progress.info(tiled_message);
577                sync_device.synchronize()?;
578                match tiled() {
579                    Ok(value) => Ok(value),
580                    Err(tile_err) if is_oom(&tile_err) => {
581                        progress.info(cpu_message);
582                        sync_device.synchronize()?;
583                        cpu_fallback()
584                    }
585                    Err(tile_err) => Err(tile_err),
586                }
587            }
588            Err(err) => Err(err),
589        }
590    }
591
592    fn qwen_vae_decode_workspace_bytes(width: u32, height: u32) -> u64 {
593        let pixels = width as u64 * height as u64;
594        // Qwen's 3D causal VAE decode has a much larger transient workspace
595        // than the final RGB tensor. This factor is intentionally conservative:
596        // native 1328² requests reserve ~7.2 GB, while small 512² requests stay
597        // below the proactive tiling threshold.
598        pixels.saturating_mul(4).saturating_mul(1024)
599    }
600
601    fn should_proactively_tile_vae_decode(
602        width: u32,
603        height: u32,
604        vae_is_cuda: bool,
605        free_vram_bytes: u64,
606    ) -> bool {
607        if !vae_is_cuda || free_vram_bytes == 0 {
608            return false;
609        }
610        let native_pixels = (QWEN_NATIVE_WIDTH * QWEN_NATIVE_HEIGHT) as u64;
611        let pixels = width as u64 * height as u64;
612        if pixels < native_pixels.saturating_mul(3) / 4 {
613            return false;
614        }
615        let required = VAE_DECODE_VRAM_THRESHOLD
616            .saturating_add(Self::qwen_vae_decode_workspace_bytes(width, height));
617        free_vram_bytes < required
618    }
619
620    fn qwen2_text_encoder_post_encode_action(
621        input: Qwen2TextEncoderResidencyInput,
622    ) -> Qwen2TextEncoderPostEncodeAction {
623        if !input.on_gpu {
624            return Qwen2TextEncoderPostEncodeAction::Drop;
625        }
626        if input.prompt_cache_miss
627            && input.transformer_resident
628            && !input.is_metal
629            && input.free_vram_bytes >= input.required_vram_bytes
630        {
631            return Qwen2TextEncoderPostEncodeAction::KeepGpu;
632        }
633        if input.keep_te_ram && !input.is_metal && !input.is_quantized {
634            return Qwen2TextEncoderPostEncodeAction::ParkCpu;
635        }
636        Qwen2TextEncoderPostEncodeAction::Drop
637    }
638
639    fn qwen2_hot_text_encoder_required_vram(
640        width: u32,
641        height: u32,
642        cfg_batch: u32,
643        dtype: DType,
644    ) -> u64 {
645        crate::device::activation_bytes(
646            width,
647            height,
648            cfg_batch,
649            crate::device::dtype_bytes(dtype),
650            crate::device::ActivationFamily::QwenImageDit,
651        )
652        .saturating_add(VAE_DECODE_VRAM_THRESHOLD)
653        .saturating_add(Self::qwen_vae_decode_workspace_bytes(width, height))
654        .saturating_add(QWEN2_HOT_TE_RESIDENCY_HEADROOM)
655    }
656
657    fn decode_vae_tiled(
658        latents: &Tensor,
659        vae: &QwenImageVae,
660        vae_device: &Device,
661        progress: &ProgressReporter,
662    ) -> Result<Tensor> {
663        for tile_size in QWEN_VAE_TILE_SIZES {
664            let overlap = (tile_size / 4).max(4);
665            progress.info(&format!(
666                "Retrying VAE decode with tiled GPU decode (tile {} overlap {})",
667                tile_size, overlap
668            ));
669            let config = TilingConfig {
670                tile_size,
671                overlap,
672                min_tile_size: 16,
673            };
674            let forward = |tile: &Tensor| {
675                let tile = tile.to_device(vae_device)?.to_dtype(DType::F32)?;
676                vae.decode(&tile).map_err(Into::into)
677            };
678            // `upscale_with_tiling` is reused here because Qwen-Image VAE decode
679            // is guaranteed to return 3-channel RGB. If a future VAE family
680            // changes that contract, this call site needs a tiler that handles
681            // arbitrary output channel counts.
682            match upscale_with_tiling(latents, &forward, 8, &config, &Device::Cpu, progress) {
683                Ok(image) => return Ok(image),
684                Err(e) if vae_device.is_cuda() && Self::is_oom_error(&e) => {
685                    if let Err(sync_err) = vae_device.synchronize() {
686                        tracing::warn!(
687                            "failed to synchronize CUDA device after tiled VAE OOM: {sync_err}"
688                        );
689                    }
690                }
691                Err(e) => return Err(e),
692            }
693        }
694
695        bail!("tiled VAE decode still ran out of memory")
696    }
697
698    fn decode_vae_with_fallback<F>(
699        latents: &Tensor,
700        vae: &QwenImageVae,
701        vae_device: &Device,
702        sync_device: &Device,
703        progress: &ProgressReporter,
704        prefer_tiled: bool,
705        load_cpu_vae: F,
706    ) -> Result<Tensor>
707    where
708        F: FnOnce() -> Result<QwenImageVae>,
709    {
710        let decode_latents = latents.to_device(vae_device)?.to_dtype(DType::F32)?;
711        Self::debug_tensor_stats("latents_pre_vae", &decode_latents);
712        Self::with_cuda_tiled_then_cpu_fallback(
713            || vae.decode(&decode_latents).map_err(Into::into),
714            || Self::decode_vae_tiled(latents, vae, vae_device, progress),
715            || {
716                let cpu_vae = load_cpu_vae()?;
717                let cpu_latents = latents.to_device(&Device::Cpu)?.to_dtype(DType::F32)?;
718                cpu_vae.decode(&cpu_latents).map_err(Into::into)
719            },
720            vae_device.is_cuda(),
721            prefer_tiled,
722            sync_device,
723            progress,
724            "VAE decode OOM on GPU — retrying with tiled GPU decode",
725            "Tiled GPU VAE decode OOM — retrying on CPU",
726            Self::is_oom_error,
727        )
728    }
729
730    /// Encode a source image through the VAE with GPU→CPU OOM fallback.
731    #[allow(clippy::too_many_arguments)]
732    fn encode_vae_with_fallback(
733        source_bytes: &[u8],
734        width: u32,
735        height: u32,
736        vae: &QwenImageVae,
737        vae_device: &Device,
738        sync_device: &Device,
739        progress: &ProgressReporter,
740        load_cpu_vae: impl FnOnce() -> Result<QwenImageVae>,
741    ) -> Result<Tensor> {
742        progress.stage_start("Encoding source image (VAE)");
743        let encode_start = Instant::now();
744
745        // Qwen-Image VAE expects [-1, 1] normalized pixels
746        let source_tensor = img_utils::decode_source_image(
747            source_bytes,
748            width,
749            height,
750            Self::img2img_source_normalize_range(),
751            vae_device,
752            DType::F32,
753        )?;
754
755        let result = Self::with_cuda_oom_cpu_fallback(
756            || vae.encode(&source_tensor).map_err(Into::into),
757            || {
758                let cpu_vae = load_cpu_vae()?;
759                let cpu_source = img_utils::decode_source_image(
760                    source_bytes,
761                    width,
762                    height,
763                    Self::img2img_source_normalize_range(),
764                    &Device::Cpu,
765                    DType::F32,
766                )?;
767                cpu_vae.encode(&cpu_source).map_err(Into::into)
768            },
769            vae_device.is_cuda(),
770            sync_device,
771            progress,
772            "VAE encode OOM on GPU — retrying on CPU",
773            Self::is_oom_error,
774        );
775
776        progress.stage_done("Encoding source image (VAE)", encode_start.elapsed());
777        result
778    }
779
780    fn choose_text_encoder_source(
781        preference: Option<&str>,
782        is_cuda: bool,
783        is_metal: bool,
784        free_vram: u64,
785        bf16_size_bytes: u64,
786        _usage: Qwen2TextEncoderUsage,
787    ) -> Result<ResolvedQwen2TextEncoder> {
788        match preference {
789            Some(tag) if tag != "auto" && tag != "bf16" => {
790                let variant = mold_core::manifest::find_qwen2_vl_variant(tag).ok_or_else(|| {
791                    anyhow::anyhow!(
792                        "unknown Qwen2.5-VL variant '{}'. Valid: bf16, auto, q8, q6, q5, q4, q3, q2",
793                        tag
794                    )
795                })?;
796                Ok(ResolvedQwen2TextEncoder {
797                    paths: vec![],
798                    vision_paths: vec![],
799                    is_gguf: true,
800                    variant_label: variant.tag.to_string(),
801                    size_bytes: variant.size_bytes,
802                    auto_use_gpu: should_use_gpu(
803                        is_cuda,
804                        is_metal,
805                        free_vram,
806                        qwen2_vram_threshold(variant.size_bytes),
807                    ),
808                })
809            }
810            Some("bf16") => Ok(ResolvedQwen2TextEncoder {
811                paths: vec![],
812                vision_paths: vec![],
813                is_gguf: false,
814                variant_label: "bf16".to_string(),
815                size_bytes: bf16_size_bytes,
816                auto_use_gpu: should_use_gpu(
817                    is_cuda,
818                    is_metal,
819                    free_vram,
820                    QWEN2_FP16_VRAM_THRESHOLD,
821                ),
822            }),
823            _ if is_metal => {
824                for tag in ["q6", "q4"] {
825                    let variant = mold_core::manifest::find_qwen2_vl_variant(tag)
826                        .expect("known Metal auto qwen2 variant missing");
827                    if fits_in_memory(
828                        is_cuda,
829                        is_metal,
830                        free_vram,
831                        qwen2_vram_threshold(variant.size_bytes),
832                    ) {
833                        return Ok(ResolvedQwen2TextEncoder {
834                            paths: vec![],
835                            vision_paths: vec![],
836                            is_gguf: true,
837                            variant_label: variant.tag.to_string(),
838                            size_bytes: variant.size_bytes,
839                            auto_use_gpu: true,
840                        });
841                    }
842                }
843                let fallback = mold_core::manifest::find_qwen2_vl_variant("q4")
844                    .expect("known Metal fallback qwen2 variant missing");
845                Ok(ResolvedQwen2TextEncoder {
846                    paths: vec![],
847                    vision_paths: vec![],
848                    is_gguf: true,
849                    variant_label: fallback.tag.to_string(),
850                    size_bytes: fallback.size_bytes,
851                    auto_use_gpu: true,
852                })
853            }
854            _ => {
855                let bf16_on_gpu =
856                    should_use_gpu(is_cuda, is_metal, free_vram, QWEN2_FP16_VRAM_THRESHOLD);
857                if bf16_on_gpu {
858                    return Ok(ResolvedQwen2TextEncoder {
859                        paths: vec![],
860                        vision_paths: vec![],
861                        is_gguf: false,
862                        variant_label: "bf16".to_string(),
863                        size_bytes: bf16_size_bytes,
864                        auto_use_gpu: true,
865                    });
866                }
867
868                if is_cuda {
869                    let fallback_tag = "q4";
870                    let fallback = mold_core::manifest::find_qwen2_vl_variant(fallback_tag)
871                        .expect("known CUDA fallback qwen2 variant missing");
872                    return Ok(ResolvedQwen2TextEncoder {
873                        paths: vec![],
874                        vision_paths: vec![],
875                        is_gguf: true,
876                        variant_label: fallback.tag.to_string(),
877                        size_bytes: fallback.size_bytes,
878                        auto_use_gpu: fits_in_memory(
879                            is_cuda,
880                            is_metal,
881                            free_vram,
882                            qwen2_vram_threshold(fallback.size_bytes),
883                        ),
884                    });
885                }
886
887                Ok(ResolvedQwen2TextEncoder {
888                    paths: vec![],
889                    vision_paths: vec![],
890                    is_gguf: false,
891                    variant_label: "bf16".to_string(),
892                    size_bytes: bf16_size_bytes,
893                    auto_use_gpu: false,
894                })
895            }
896        }
897    }
898
899    fn tensor_stats(tensor: &Tensor) -> Result<QwenTensorStats> {
900        let t = tensor.to_dtype(DType::F32)?;
901        let values = t.flatten_all()?.to_vec1::<f32>()?;
902        let mut min = f32::INFINITY;
903        let mut max = f32::NEG_INFINITY;
904        let mut sum = 0.0f64;
905        let mut finite_count = 0usize;
906        let mut nan_count = 0u64;
907        let mut pos_inf_count = 0u64;
908        let mut neg_inf_count = 0u64;
909        for value in &values {
910            if value.is_nan() {
911                nan_count += 1;
912            } else if *value == f32::INFINITY {
913                pos_inf_count += 1;
914            } else if *value == f32::NEG_INFINITY {
915                neg_inf_count += 1;
916            } else {
917                min = min.min(*value);
918                max = max.max(*value);
919                sum += *value as f64;
920                finite_count += 1;
921            }
922        }
923        let mean = if finite_count == 0 {
924            f32::NAN
925        } else {
926            (sum / finite_count as f64) as f32
927        };
928        if finite_count == 0 {
929            min = f32::NAN;
930            max = f32::NAN;
931        }
932        Ok(QwenTensorStats {
933            min,
934            max,
935            mean,
936            nan_count,
937            pos_inf_count,
938            neg_inf_count,
939            total: values.len(),
940        })
941    }
942
943    fn format_tensor_stats(name: &str, stats: QwenTensorStats) -> String {
944        format!(
945            "[qwen-debug] {name}: min={:.4} max={:.4} mean={:.4} NaN={}/{} ({:.1}%) +Inf={} -Inf={}",
946            stats.min,
947            stats.max,
948            stats.mean,
949            stats.nan_count,
950            stats.total,
951            stats.nan_count as f64 / stats.total.max(1) as f64 * 100.0,
952            stats.pos_inf_count,
953            stats.neg_inf_count
954        )
955    }
956
957    fn near_black_image_stats(stats: QwenTensorStats) -> bool {
958        if stats.nan_count > 0
959            || stats.pos_inf_count > 0
960            || stats.neg_inf_count > 0
961            || !stats.min.is_finite()
962            || !stats.max.is_finite()
963            || !stats.mean.is_finite()
964        {
965            return false;
966        }
967        let scale = if stats.max <= 1.0 { 1.0 } else { 255.0 };
968        stats.max <= 0.02 * scale && stats.mean <= 0.01 * scale
969    }
970
971    fn validate_qwen_tensor_boundary(name: &str, tensor: &Tensor) -> Result<QwenTensorStats> {
972        let stats = Self::tensor_stats(tensor)?;
973        if stats.nan_count > 0
974            || stats.pos_inf_count > 0
975            || stats.neg_inf_count > 0
976            || !stats.min.is_finite()
977            || !stats.max.is_finite()
978            || !stats.mean.is_finite()
979        {
980            bail!(
981                "Qwen diagnostic boundary '{name}' contains non-finite values: {}",
982                Self::format_tensor_stats(name, stats)
983            );
984        }
985        Ok(stats)
986    }
987
988    fn debug_tensor_stats(name: &str, tensor: &Tensor) {
989        if std::env::var_os("MOLD_QWEN_DEBUG").is_none() {
990            return;
991        }
992        match Self::tensor_stats(tensor) {
993            Ok(stats) => eprintln!("{}", Self::format_tensor_stats(name, stats)),
994            Err(err) => eprintln!("[qwen-debug] {name}: <failed: {err}>"),
995        }
996    }
997
998    pub fn new(
999        model_name: String,
1000        paths: ModelPaths,
1001        load_strategy: LoadStrategy,
1002        gpu_ordinal: usize,
1003        offload: bool,
1004        shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
1005    ) -> Self {
1006        Self {
1007            base: EngineBase::new(model_name, paths, load_strategy, gpu_ordinal),
1008            prompt_cache: Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY)),
1009            offload,
1010            pending_placement: None,
1011            pending_loras: Vec::new(),
1012            active_lora_fingerprint: Vec::new(),
1013            shared_pool,
1014        }
1015    }
1016
1017    fn load_text_tokenizer(&self, tokenizer_path: &Path) -> Result<Arc<Tokenizer>> {
1018        if let Some(shared_pool) = &self.shared_pool {
1019            return shared_pool.lock().unwrap().load_tokenizer(tokenizer_path);
1020        }
1021        Tokenizer::from_file(tokenizer_path)
1022            .map(Arc::new)
1023            .map_err(|e| anyhow::anyhow!("failed to load Qwen2.5 tokenizer: {e}"))
1024    }
1025
1026    fn encode_prompt_cached(
1027        progress: &ProgressReporter,
1028        prompt_cache: &Mutex<LruCache<String, CachedPromptConditioning>>,
1029        text_encoder: &mut encoders::qwen2_text::Qwen2TextEncoder,
1030        prompt: &str,
1031        device: &Device,
1032        dtype: DType,
1033    ) -> Result<(Tensor, Tensor)> {
1034        let cache_key = prompt_text_key(prompt);
1035        if let Some(cached) = prompt_cache
1036            .lock()
1037            .expect("cache poisoned")
1038            .get_cloned(&cache_key)
1039        {
1040            progress.cache_hit("prompt conditioning");
1041            return cached.restore(device, dtype);
1042        }
1043
1044        progress.stage_start("Encoding prompt (Qwen2.5)");
1045        let encode_start = Instant::now();
1046        let (hidden_states, _attention_mask, valid_len) =
1047            text_encoder.encode(prompt, device, dtype)?;
1048        progress.stage_done("Encoding prompt (Qwen2.5)", encode_start.elapsed());
1049
1050        prompt_cache.lock().expect("cache poisoned").insert(
1051            cache_key,
1052            CachedPromptConditioning::from_parts(&hidden_states, valid_len)?,
1053        );
1054
1055        let mut mask = vec![0u8; hidden_states.dim(1)?];
1056        for value in &mut mask[..valid_len] {
1057            *value = 1;
1058        }
1059        let attention_mask = Tensor::from_vec(mask, (1, hidden_states.dim(1)?), device)?;
1060        Ok((hidden_states, attention_mask))
1061    }
1062
1063    fn spill_conditioning_to_cpu(
1064        hidden_states: Tensor,
1065        attention_mask: Tensor,
1066    ) -> Result<(Tensor, Tensor)> {
1067        Ok((
1068            hidden_states
1069                .to_device(&Device::Cpu)?
1070                .to_dtype(DType::F32)?,
1071            attention_mask.to_device(&Device::Cpu)?,
1072        ))
1073    }
1074
1075    fn maybe_spill_conditioning(
1076        use_cpu_staging: bool,
1077        hidden_states: Tensor,
1078        attention_mask: Tensor,
1079    ) -> Result<(Tensor, Tensor)> {
1080        if use_cpu_staging {
1081            Self::spill_conditioning_to_cpu(hidden_states, attention_mask)
1082        } else {
1083            Ok((hidden_states, attention_mask))
1084        }
1085    }
1086
1087    /// Resolve transformer shard paths.
1088    fn transformer_paths(&self) -> Vec<std::path::PathBuf> {
1089        if !self.base.paths.transformer_shards.is_empty() {
1090            self.base.paths.transformer_shards.clone()
1091        } else {
1092            vec![self.base.paths.transformer.clone()]
1093        }
1094    }
1095
1096    fn detect_is_quantized(&self) -> bool {
1097        self.base
1098            .paths
1099            .transformer
1100            .extension()
1101            .and_then(|e| e.to_str())
1102            .map(|e| e.eq_ignore_ascii_case("gguf"))
1103            .unwrap_or(false)
1104    }
1105
1106    /// Validate required paths exist.
1107    fn validate_paths(&self) -> Result<std::path::PathBuf> {
1108        let text_tokenizer_path =
1109            self.base.paths.text_tokenizer.as_ref().ok_or_else(|| {
1110                anyhow::anyhow!("text tokenizer path required for Qwen-Image models")
1111            })?;
1112        if !text_tokenizer_path.exists() {
1113            bail!(
1114                "text tokenizer file not found: {}",
1115                text_tokenizer_path.display()
1116            );
1117        }
1118
1119        let xformer_paths = self.transformer_paths();
1120        for path in &xformer_paths {
1121            if !path.exists() {
1122                bail!("transformer file not found: {}", path.display());
1123            }
1124        }
1125        if !self.base.paths.vae.exists() {
1126            bail!("VAE file not found: {}", self.base.paths.vae.display());
1127        }
1128
1129        Ok(text_tokenizer_path.clone())
1130    }
1131
1132    fn quantized_cuda_cfg_headroom(width: usize, height: usize) -> u64 {
1133        let native_pixels = (QWEN_NATIVE_WIDTH * QWEN_NATIVE_HEIGHT) as f64;
1134        let pixels = (width.max(1) * height.max(1)) as f64;
1135        let scaled =
1136            (QWEN_GGUF_NATIVE_CFG_HEADROOM as f64 * (pixels / native_pixels)).round() as u64;
1137        scaled.max(QWEN_GGUF_MIN_CFG_HEADROOM)
1138    }
1139
1140    fn should_split_cfg_quantized_cuda(
1141        transformer_size: u64,
1142        free_vram: u64,
1143        width: usize,
1144        height: usize,
1145    ) -> bool {
1146        if free_vram == 0 {
1147            // If VRAM probing fails, bias toward the safer split-CFG path
1148            // instead of assuming batched CFG will fit.
1149            return true;
1150        }
1151        let estimated_peak =
1152            transformer_size.saturating_add(Self::quantized_cuda_cfg_headroom(width, height));
1153        estimated_peak > free_vram
1154    }
1155
    /// Load the Qwen-Image transformer from disk, choosing a backend from the
    /// checkpoint format and the available VRAM.
    ///
    /// - GGUF (per `detect_is_quantized`): builds a quantized model, with any
    ///   pending LoRAs merged through `gguf_lora_var_builder`. On CUDA, the
    ///   low-memory split-CFG mode is chosen when `self.offload` is set, or when
    ///   the estimated peak exceeds the usable free VRAM. The estimate is the
    ///   file size plus the resolution-scaled CFG headroom.
    /// - Safetensors (BF16, or FP8 detected from the first shard): block-offloaded
    ///   when `self.offload` is set or `should_offload` reports that the weights
    ///   plus the resolution-dependent activation budget will not fit. Otherwise
    ///   the model is fully resident, with pending LoRAs merged via
    ///   `build_bf16_lora_var_builder`.
    ///
    /// `width`/`height` only feed memory budgeting and log messages here.
    ///
    /// # Errors
    /// Fails when LoRAs are pending on the block-offload path. Also propagates
    /// LoRA adapter, weight-loading and model-construction errors.
    fn load_transformer(
        &self,
        device: &Device,
        dtype: DType,
        cfg: &QwenImageConfig,
        width: usize,
        height: usize,
    ) -> Result<QwenImageTransformer> {
        let active_loras = &self.pending_loras;
        let has_lora = !active_loras.is_empty();
        if self.detect_is_quantized() {
            // Unreadable metadata counts as 0 bytes, which makes the peak estimate optimistic.
            let transformer_size = std::fs::metadata(&self.base.paths.transformer)
                .map(|m| m.len())
                .unwrap_or(0);
            // Reserve-adjusted reading: split-CFG is a budget decision.
            // A failed probe yields 0, which should_split_cfg_quantized_cuda treats as "split".
            let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
            let split_cfg_for_memory = device.is_cuda()
                && (self.offload
                    || Self::should_split_cfg_quantized_cuda(
                        transformer_size,
                        free,
                        width,
                        height,
                    ));
            if self.offload && device.is_cuda() {
                self.base.progress.info(
                    "Quantized Qwen CUDA offload requested — using low-memory split-CFG mode until GGUF block offload lands",
                );
            } else if split_cfg_for_memory {
                let estimated_peak = transformer_size
                    .saturating_add(Self::quantized_cuda_cfg_headroom(width, height));
                self.base.progress.info(&format!(
                    "Using low-memory quantized Qwen CUDA path (est. peak {}, {} free at {}x{})",
                    fmt_gb(estimated_peak),
                    fmt_gb(free),
                    width,
                    height,
                ));
            }
            let vb = if has_lora {
                // Adapters and LoRA weights are zipped pairwise, so each spec keeps its own scale.
                let adapters = super::lora::load_lora_adapters(active_loras, &self.base.progress)?;
                let specs: Vec<super::lora::QwenImageLoraSpec<'_>> = adapters
                    .iter()
                    .zip(active_loras.iter())
                    .map(|(adapter, w)| super::lora::QwenImageLoraSpec {
                        adapter: adapter.as_ref(),
                        scale: w.scale,
                        path_hash: super::lora::lora_path_hash(&w.path),
                    })
                    .collect();
                super::lora::gguf_lora_var_builder(
                    &self.base.paths.transformer,
                    &specs,
                    device,
                    &self.base.progress,
                    None,
                )?
            } else {
                quantized_var_builder::VarBuilder::from_gguf(&self.base.paths.transformer, device)?
            };
            // The last argument is the negation of split_cfg_for_memory; presumably it
            // enables batched CFG — confirm against QuantizedQwenImageTransformer2DModel::new.
            Ok(QwenImageTransformer::Quantized(
                QuantizedQwenImageTransformer2DModel::new(cfg, vb, device, !split_cfg_for_memory)?,
            ))
        } else {
            let xformer_paths = self.transformer_paths();
            // Only the first shard is inspected to decide FP8 vs BF16.
            let is_fp8 = xformer_paths
                .first()
                .map(|p| safetensors_is_fp8(p))
                .unwrap_or(false);

            // FP8 weights stay as F8E4M3 in VRAM (~19.5GB, 1 byte/param).
            // Per-layer dequant to BF16 during forward adds ~113MB transient.
            // BF16 weights are 2 bytes/param (~40GB).
            let mem_size: u64 = xformer_paths
                .iter()
                .filter_map(|p| std::fs::metadata(p).ok())
                .map(|m| m.len())
                .sum();
            // Reserve-adjusted reading: should_offload budgets against this.
            let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
            // Qwen-Image runs CFG by default; activation budget scales with
            // resolution to replace the previous fixed 3 GB heuristic.
            // NOTE(review): `as u32` truncates silently; harmless for realistic image sizes.
            let activation_budget = crate::device::activation_bytes(
                width as u32,
                height as u32,
                2,
                crate::device::dtype_bytes(dtype),
                crate::device::ActivationFamily::QwenImageDit,
            );
            let use_offload =
                self.offload || crate::device::should_offload(mem_size, free, activation_budget);

            if is_fp8 {
                self.base
                    .progress
                    .info("Detected FP8 safetensors — loading with scale dequantization");
            }

            if use_offload {
                if has_lora {
                    bail!(
                        "Qwen-Image LoRA support is not yet wired through the block-offload \
                         transformer path. Disable offload (drop --offload / unset MOLD_OFFLOAD), \
                         or pick a checkpoint that fits without offload, to use LoRAs."
                    );
                }
                // Create TWO VarBuilders: GPU for blocks that fit, CPU for overflow.
                let (gpu_vb, cpu_vb) = if is_fp8 {
                    let gpu = crate::weight_loader::load_fp8_safetensors(
                        &xformer_paths,
                        device,
                        "Qwen-Image transformer (offload, GPU)",
                        &self.base.progress,
                    )?;
                    let cpu = crate::weight_loader::load_fp8_safetensors(
                        &xformer_paths,
                        &Device::Cpu,
                        "Qwen-Image transformer (offload, CPU)",
                        &self.base.progress,
                    )?;
                    (gpu, cpu)
                } else {
                    let gpu = crate::weight_loader::load_safetensors_with_progress(
                        &xformer_paths,
                        dtype,
                        device,
                        "Qwen-Image transformer (offload, GPU)",
                        &self.base.progress,
                    )?;
                    // SAFETY (review): memory-mapping assumes the safetensors files are not
                    // modified or truncated while mapped — TODO confirm the model store
                    // guarantees this for the lifetime of the CPU VarBuilder.
                    let cpu = unsafe {
                        candle_nn::VarBuilder::from_mmaped_safetensors(
                            &xformer_paths
                                .iter()
                                .map(|p| p.as_path())
                                .collect::<Vec<_>>(),
                            DType::BF16,
                            &Device::Cpu,
                        )?
                    };
                    (gpu, cpu)
                };
                Ok(QwenImageTransformer::Offloaded(
                    super::offload::OffloadedQwenImageTransformer::load(
                        gpu_vb,
                        cpu_vb,
                        cfg,
                        device,
                        self.base.gpu_ordinal,
                        activation_budget,
                        &self.base.progress,
                    )?,
                ))
            } else {
                // NOTE(review): with FP8 + LoRA, the FP8 info line above is followed by a
                // second FP8 message from build_bf16_lora_var_builder.
                let xformer_vb = if has_lora {
                    self.build_bf16_lora_var_builder(
                        &xformer_paths,
                        dtype,
                        device,
                        is_fp8,
                        active_loras,
                    )?
                } else if is_fp8 {
                    crate::weight_loader::load_fp8_safetensors(
                        &xformer_paths,
                        device,
                        "Qwen-Image transformer",
                        &self.base.progress,
                    )?
                } else {
                    crate::weight_loader::load_safetensors_with_progress(
                        &xformer_paths,
                        dtype,
                        device,
                        "Qwen-Image transformer",
                        &self.base.progress,
                    )?
                };
                Ok(QwenImageTransformer::BF16(
                    QwenImageTransformer2DModel::new(cfg, xformer_vb)?,
                ))
            }
        }
    }
1340
    /// Construct a `VarBuilder` for the BF16/FP8 in-memory path with a
    /// LoRA-merging `SimpleBackend` wrapping the underlying mmap (or
    /// `NativeFp8Backend`). Each `vb.get()` call delivers a tensor with
    /// `W' = W + scale·(B @ A)` already merged in.
    ///
    /// `loras` are loaded as adapters and zipped pairwise with their weights,
    /// so each adapter keeps its own `scale` and path hash. FP8 checkpoints are
    /// always materialized as BF16 regardless of `dtype`.
    ///
    /// # Errors
    /// Propagates adapter loading, safetensors mmap, and LoRA-wrapping failures.
    fn build_bf16_lora_var_builder<'a>(
        &self,
        xformer_paths: &[std::path::PathBuf],
        dtype: DType,
        device: &Device,
        is_fp8: bool,
        loras: &[mold_core::LoraWeight],
    ) -> Result<candle_nn::VarBuilder<'a>> {
        let adapters = super::lora::load_lora_adapters(loras, &self.base.progress)?;
        let specs: Vec<super::lora::QwenImageLoraSpec<'_>> = adapters
            .iter()
            .zip(loras.iter())
            .map(|(adapter, w)| super::lora::QwenImageLoraSpec {
                adapter: adapter.as_ref(),
                scale: w.scale,
                path_hash: super::lora::lora_path_hash(&w.path),
            })
            .collect();

        let path_refs: Vec<&std::path::Path> = xformer_paths.iter().map(|p| p.as_path()).collect();
        // SAFETY (review): memory-mapping assumes the safetensors files are not modified
        // or truncated while mapped — TODO confirm this holds for the returned VarBuilder.
        let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::multi(&path_refs)? };
        let inner: Box<dyn candle_nn::var_builder::SimpleBackend> = if is_fp8 {
            // FP8 path needs the `NativeFp8Backend` so F8E4M3 weights
            // stay F8E4M3 in VRAM; the LoRA wrapper merges deltas in
            // F32 and the per-layer dequant in `QwenLinear::Fp8::forward`
            // sees pre-merged weights as expected.
            self.base
                .progress
                .info("Detected FP8 safetensors — loading with LoRA-merging wrapper");
            Box::new(crate::weight_loader::NativeFp8Backend::from_mmap(tensors))
        } else {
            // candle's `MmapedSafetensors` implements `SimpleBackend`
            // directly; use it as the inner layer of the LoRA wrapper.
            Box::new(tensors)
        };

        let wrapped =
            super::lora::wrap_backend_with_lora(inner, &specs, &self.base.progress, None)?;

        // FP8 ignores the caller's dtype and always targets BF16.
        let target_dtype = if is_fp8 { DType::BF16 } else { dtype };
        Ok(candle_nn::VarBuilder::from_backend(
            wrapped,
            target_dtype,
            device.clone(),
        ))
    }
1391
1392    /// Load VAE from disk.
1393    fn load_vae(&self, device: &Device, dtype: DType) -> Result<QwenImageVae> {
1394        let vb = self.load_vae_var_builder(device, dtype)?;
1395        Ok(QwenImageVae::from_var_builder(vb, device, dtype)?)
1396    }
1397
1398    fn load_vae_cpu_tensors(&self) -> Result<Option<Arc<HashMap<String, Tensor>>>> {
1399        let Some(shared_pool) = &self.shared_pool else {
1400            return Ok(None);
1401        };
1402        shared_pool
1403            .lock()
1404            .unwrap()
1405            .load_safetensors_cpu_tensors(std::slice::from_ref(&self.base.paths.vae))
1406    }
1407
1408    fn load_vae_var_builder<'a>(
1409        &self,
1410        device: &Device,
1411        dtype: DType,
1412    ) -> Result<candle_nn::VarBuilder<'a>> {
1413        if let Some(tensors) = self.load_vae_cpu_tensors()? {
1414            return Ok(encoders::park::varbuilder_from_parked(
1415                tensors.as_ref(),
1416                dtype,
1417                device,
1418            ));
1419        }
1420
1421        crate::weight_loader::load_safetensors_with_progress(
1422            std::slice::from_ref(&self.base.paths.vae),
1423            dtype,
1424            device,
1425            "Qwen-Image VAE",
1426            &self.base.progress,
1427        )
1428    }
1429
1430    /// Load text encoder from disk.
1431    ///
1432    /// FP8 text encoders are loaded on GPU with BF16 dtype — candle's CUDA cast
1433    /// kernel handles F8E4M3→BF16 conversion during tensor loading.
1434    fn resolve_text_encoder_source(
1435        &self,
1436        gpu_device: &Device,
1437        free_vram: u64,
1438        usage: Qwen2TextEncoderUsage,
1439    ) -> Result<ResolvedQwen2TextEncoder> {
1440        let preference = std::env::var("MOLD_QWEN2_VARIANT").ok();
1441        self.resolve_text_encoder_source_with_preference(
1442            gpu_device,
1443            free_vram,
1444            usage,
1445            preference.as_deref(),
1446        )
1447    }
1448
1449    fn resolve_text_encoder_source_with_preference(
1450        &self,
1451        gpu_device: &Device,
1452        free_vram: u64,
1453        usage: Qwen2TextEncoderUsage,
1454        preference: Option<&str>,
1455    ) -> Result<ResolvedQwen2TextEncoder> {
1456        let is_cuda = gpu_device.is_cuda();
1457        let is_metal = gpu_device.is_metal();
1458        let bf16_size_bytes = self
1459            .base
1460            .paths
1461            .text_encoder_files
1462            .iter()
1463            .filter_map(|p| std::fs::metadata(p).ok())
1464            .map(|m| m.len())
1465            .sum();
1466        if self.is_edit_family() {
1467            let mut resolved = Self::choose_text_encoder_source(
1468                preference,
1469                is_cuda,
1470                is_metal,
1471                free_vram,
1472                bf16_size_bytes,
1473                Qwen2TextEncoderUsage::Resident,
1474            )?;
1475            resolved.vision_paths = self.base.paths.text_encoder_files.clone();
1476            if resolved.is_gguf {
1477                let variant = mold_core::manifest::find_qwen2_vl_variant(&resolved.variant_label)
1478                    .ok_or_else(|| {
1479                    anyhow::anyhow!("unknown Qwen2.5-VL variant '{}'", resolved.variant_label)
1480                })?;
1481                resolved.paths = vec![
1482                    crate::encoders::variant_resolution::resolve_qwen2_vl_gguf_path(
1483                        &self.base.progress,
1484                        variant,
1485                    )?,
1486                ];
1487            } else {
1488                resolved.paths = self.base.paths.text_encoder_files.clone();
1489            }
1490            return Ok(resolved);
1491        }
1492        let mut resolved = Self::choose_text_encoder_source(
1493            preference,
1494            is_cuda,
1495            is_metal,
1496            free_vram,
1497            bf16_size_bytes,
1498            usage,
1499        )?;
1500
1501        if resolved.is_gguf {
1502            let variant = mold_core::manifest::find_qwen2_vl_variant(&resolved.variant_label)
1503                .ok_or_else(|| {
1504                    anyhow::anyhow!("unknown Qwen2.5-VL variant '{}'", resolved.variant_label)
1505                })?;
1506            resolved.paths = vec![
1507                crate::encoders::variant_resolution::resolve_qwen2_vl_gguf_path(
1508                    &self.base.progress,
1509                    variant,
1510                )?,
1511            ];
1512        } else {
1513            resolved.paths = self.base.paths.text_encoder_files.clone();
1514        }
1515        resolved.vision_paths = vec![];
1516
1517        match preference {
1518            Some(tag) if tag != "auto" && tag != "bf16" => self.base.progress.info(&format!(
1519                "Using quantized Qwen2.5-VL {} ({}) on {} (explicit)",
1520                resolved.variant_label,
1521                fmt_gb(resolved.size_bytes),
1522                if resolved.auto_use_gpu { "GPU" } else { "CPU" },
1523            )),
1524            Some("bf16") => {}
1525            _ if is_metal && resolved.is_gguf && resolved.variant_label == "q6" => self
1526                .base
1527                .progress
1528                .info(&format!(
1529                    "Metal auto mode selected quantized Qwen2.5-VL {} ({}) for lower memory pressure",
1530                    resolved.variant_label,
1531                    fmt_gb(resolved.size_bytes),
1532                )),
1533            _ if is_metal && resolved.is_gguf => self.base.progress.info(&format!(
1534                "Metal auto mode forcing quantized Qwen2.5-VL {} ({}) to avoid BF16 memory pressure",
1535                resolved.variant_label,
1536                fmt_gb(resolved.size_bytes),
1537            )),
1538            _ if is_cuda && resolved.is_gguf && resolved.auto_use_gpu => self.base.progress.info(
1539                &format!(
1540                    "CUDA auto mode selected quantized Qwen2.5-VL {} ({}) on GPU",
1541                    resolved.variant_label,
1542                    fmt_gb(resolved.size_bytes),
1543                ),
1544            ),
1545            _ if is_cuda && resolved.is_gguf => self.base.progress.info(&format!(
1546                "CUDA auto mode selected quantized Qwen2.5-VL {} ({}) on CPU to avoid large BF16 host residency",
1547                resolved.variant_label,
1548                fmt_gb(resolved.size_bytes),
1549            )),
1550            _ => {}
1551        }
1552
1553        Ok(resolved)
1554    }
1555
1556    fn can_keep_transformer_hot_for_vae(loaded: &LoadedQwenImage) -> bool {
1557        Self::qwen_transformer_can_stay_hot_for_vae(
1558            loaded.device.is_cuda(),
1559            loaded.vae_device.is_cuda(),
1560            matches!(
1561                loaded.transformer.as_ref(),
1562                Some(QwenImageTransformer::Quantized(_))
1563            ),
1564        )
1565    }
1566
1567    fn qwen_transformer_can_stay_hot_for_vae(
1568        transformer_is_cuda: bool,
1569        vae_is_cuda: bool,
1570        transformer_is_quantized: bool,
1571    ) -> bool {
1572        transformer_is_cuda && vae_is_cuda && transformer_is_quantized
1573    }
1574
1575    fn decode_vae_gpu_only(
1576        latents: &Tensor,
1577        vae: &QwenImageVae,
1578        vae_device: &Device,
1579        sync_device: &Device,
1580        progress: &ProgressReporter,
1581        prefer_tiled: bool,
1582    ) -> Result<Tensor> {
1583        if vae_device.is_cuda() && prefer_tiled {
1584            progress.info("Selecting tiled GPU VAE decode proactively");
1585            return Self::decode_vae_tiled(latents, vae, vae_device, progress);
1586        }
1587
1588        let decode_latents = latents.to_device(vae_device)?.to_dtype(DType::F32)?;
1589        match vae.decode(&decode_latents) {
1590            Ok(image) => Ok(image),
1591            Err(e) if vae_device.is_cuda() && Self::is_oom_error(&e) => {
1592                progress.info(
1593                    "Resident-transformer VAE decode OOM on GPU — retrying with tiled GPU decode before dropping transformer",
1594                );
1595                sync_device.synchronize()?;
1596                Self::decode_vae_tiled(latents, vae, vae_device, progress)
1597            }
1598            Err(e) => Err(e.into()),
1599        }
1600    }
1601
1602    fn load_text_encoder(
1603        &self,
1604        resolved: &ResolvedQwen2TextEncoder,
1605        tokenizer_path: &std::path::PathBuf,
1606        tokenizer: Arc<Tokenizer>,
1607        device: &Device,
1608        dtype: DType,
1609        preload_weights: bool,
1610    ) -> Result<encoders::qwen2_text::Qwen2TextEncoder> {
1611        if resolved.is_gguf {
1612            if preload_weights {
1613                encoders::qwen2_text::Qwen2TextEncoder::load_gguf_with_tokenizer(
1614                    &resolved.paths[0],
1615                    tokenizer_path,
1616                    Some(tokenizer),
1617                    device,
1618                    dtype,
1619                    &resolved.vision_paths,
1620                    &self.base.progress,
1621                )
1622            } else {
1623                encoders::qwen2_text::Qwen2TextEncoder::prepare_gguf_with_tokenizer(
1624                    &resolved.paths[0],
1625                    tokenizer_path,
1626                    Some(tokenizer),
1627                    device,
1628                    dtype,
1629                    &resolved.vision_paths,
1630                )
1631            }
1632        } else {
1633            let is_fp8 = text_encoder_is_fp8(&resolved.paths);
1634            if is_fp8 {
1635                self.base
1636                    .progress
1637                    .info("Detected FP8 text encoder — loading as BF16 on GPU");
1638            }
1639            if preload_weights {
1640                encoders::qwen2_text::Qwen2TextEncoder::load_bf16_with_tokenizer(
1641                    &resolved.paths,
1642                    tokenizer_path,
1643                    Some(tokenizer),
1644                    device,
1645                    dtype,
1646                    self.is_edit_family(),
1647                    &self.base.progress,
1648                )
1649            } else {
1650                encoders::qwen2_text::Qwen2TextEncoder::prepare_bf16_with_tokenizer(
1651                    &resolved.paths,
1652                    tokenizer_path,
1653                    Some(tokenizer),
1654                    device,
1655                    dtype,
1656                    self.is_edit_family(),
1657                )
1658            }
1659        }
1660    }
1661
1662    /// Resolve text encoder device placement and optional CPU staging.
1663    fn resolve_text_encoder_plan(
1664        &self,
1665        gpu_device: &Device,
1666        resolved: &ResolvedQwen2TextEncoder,
1667        free_vram: u64,
1668    ) -> (Qwen2TextEncoderPlan, String) {
1669        let is_cuda = gpu_device.is_cuda();
1670        let is_metal = gpu_device.is_metal();
1671        let plan = Self::qwen2_text_encoder_plan_for_mode(
1672            Qwen2TextEncoderMode::from_env(),
1673            is_cuda,
1674            is_metal,
1675            resolved,
1676        );
1677        let label = if plan.use_gpu { "GPU" } else { "CPU" };
1678        if plan.use_cpu_staging {
1679            self.base
1680                .progress
1681                .info("Qwen2.5 text encoder on GPU with CPU staging after encoding");
1682        } else if !plan.use_gpu {
1683            if resolved.is_gguf {
1684                self.base.progress.info(&format!(
1685                    "Qwen2.5 text encoder on CPU ({} variant {}, {} free)",
1686                    resolved.variant_label,
1687                    fmt_gb(resolved.size_bytes),
1688                    fmt_gb(free_vram),
1689                ));
1690            } else if is_metal || is_cuda {
1691                self.base.progress.info(&format!(
1692                    "Qwen2.5 text encoder on CPU ({} free < {} threshold)",
1693                    fmt_gb(free_vram),
1694                    fmt_gb(QWEN2_FP16_VRAM_THRESHOLD),
1695                ));
1696            }
1697        }
1698        (plan, label.to_string())
1699    }
1700
1701    fn qwen2_text_encoder_plan_for_mode(
1702        mode: Qwen2TextEncoderMode,
1703        is_cuda: bool,
1704        is_metal: bool,
1705        resolved: &ResolvedQwen2TextEncoder,
1706    ) -> Qwen2TextEncoderPlan {
1707        match mode {
1708            Qwen2TextEncoderMode::Gpu => Qwen2TextEncoderPlan {
1709                use_gpu: is_cuda || is_metal,
1710                use_cpu_staging: false,
1711            },
1712            Qwen2TextEncoderMode::CpuStage => Qwen2TextEncoderPlan {
1713                use_gpu: is_cuda || is_metal,
1714                use_cpu_staging: is_cuda || is_metal,
1715            },
1716            Qwen2TextEncoderMode::Cpu => Qwen2TextEncoderPlan {
1717                use_gpu: false,
1718                use_cpu_staging: false,
1719            },
1720            Qwen2TextEncoderMode::Auto => Qwen2TextEncoderPlan {
1721                use_gpu: resolved.auto_use_gpu,
1722                use_cpu_staging: is_metal && resolved.auto_use_gpu && !resolved.is_gguf,
1723            },
1724        }
1725    }
1726
1727    /// Load all model components (Eager mode).
1728    ///
1729    /// On error, `self.base.loaded` remains `None` — all components are assembled into
1730    /// local variables and only stored in `self.base.loaded` on success, so partial loads
1731    /// cannot leave the engine in an inconsistent state.
1732    pub fn load(&mut self) -> Result<()> {
1733        if self.base.loaded.is_some() {
1734            return Ok(());
1735        }
1736
1737        // Sequential mode defers loading to generate_sequential()
1738        if self.base.load_strategy == LoadStrategy::Sequential {
1739            return Ok(());
1740        }
1741
1742        tracing::info!(model = %self.base.model_name, "loading Qwen-Image model components...");
1743
1744        let text_tokenizer_path = self.validate_paths()?;
1745        let transformer_ref = effective_device_ref(
1746            self.pending_placement.as_ref(),
1747            |adv| Some(adv.transformer),
1748            false,
1749        );
1750        let device = crate::device::resolve_device(Some(transformer_ref), || {
1751            crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)
1752        })?;
1753        let transformer_cfg = self.transformer_config();
1754        let transformer_is_quantized = self.detect_is_quantized();
1755        // FP8 safetensors are loaded as BF16 via CPU (candle CUDA kernel bug
1756        // prevents direct F8E4M3→BF16 on GPU; CPU cast works fine). All paths
1757        // use BF16 as runtime dtype since the model trains and computes in BF16.
1758        let dtype = crate::engine::gpu_dtype(&device);
1759
1760        // Load transformer
1761        let xformer_paths = self.transformer_paths();
1762        let xformer_label = if transformer_is_quantized {
1763            "Loading Qwen-Image transformer (quantized)".to_string()
1764        } else {
1765            format!(
1766                "Loading Qwen-Image transformer ({} shards)",
1767                xformer_paths.len()
1768            )
1769        };
1770        self.base.progress.stage_start(&xformer_label);
1771        let xformer_start = Instant::now();
1772        let transformer = self.load_transformer(
1773            &device,
1774            dtype,
1775            &transformer_cfg,
1776            QWEN_NATIVE_WIDTH,
1777            QWEN_NATIVE_HEIGHT,
1778        )?;
1779        self.base
1780            .progress
1781            .stage_done(&xformer_label, xformer_start.elapsed());
1782        tracing::info!("Qwen-Image transformer loaded");
1783
1784        // Decide device placement for VAE and text encoder.
1785        // Log raw, budget against the reserve-adjusted reading.
1786        let free_raw = free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
1787        let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
1788        let is_cuda = device.is_cuda();
1789        let is_metal = device.is_metal();
1790        if free_raw > 0 {
1791            self.base.progress.info(&format!(
1792                "Free VRAM after transformer: {}",
1793                fmt_gb(free_raw)
1794            ));
1795        }
1796
1797        let vae_on_gpu = should_use_gpu(is_cuda, is_metal, free, VAE_DECODE_VRAM_THRESHOLD);
1798        let vae_ref =
1799            effective_device_ref(self.pending_placement.as_ref(), |adv| Some(adv.vae), false);
1800        let vae_device = crate::device::resolve_device(Some(vae_ref), || {
1801            Ok(if vae_on_gpu {
1802                device.clone()
1803            } else {
1804                Device::Cpu
1805            })
1806        })?;
1807        let vae_on_gpu = !vae_device.is_cpu();
1808        // Always decode in F32 — BF16 convolutions accumulate quantization noise across
1809        // the 4 upsampling blocks, producing visible grain. Matches diffusers' force_upcast.
1810        let vae_dtype = DType::F32;
1811        let vae_device_label = if vae_on_gpu { "GPU" } else { "CPU" };
1812
1813        // Load VAE
1814        let vae_label = format!("Loading Qwen-Image VAE ({}, F32)", vae_device_label);
1815        self.base.progress.stage_start(&vae_label);
1816        let vae_start = Instant::now();
1817        let vae = self.load_vae(&vae_device, vae_dtype)?;
1818        self.base
1819            .progress
1820            .stage_done(&vae_label, vae_start.elapsed());
1821
1822        // Load text encoder
1823        let resolved_text_encoder =
1824            self.resolve_text_encoder_source(&device, free, Qwen2TextEncoderUsage::Resident)?;
1825        let (te_plan, te_auto_device_label) =
1826            self.resolve_text_encoder_plan(&device, &resolved_text_encoder, free);
1827        let qwen_ref = effective_device_ref(self.pending_placement.as_ref(), |adv| adv.qwen, true);
1828        let auto_te_device = if te_plan.use_gpu {
1829            device.clone()
1830        } else {
1831            Device::Cpu
1832        };
1833        let te_device =
1834            crate::device::resolve_device(Some(qwen_ref), || Ok(auto_te_device.clone()))?;
1835        let te_use_gpu = !te_device.is_cpu();
1836        let te_device_label: String = if te_use_gpu == te_plan.use_gpu {
1837            te_auto_device_label
1838        } else if te_use_gpu {
1839            "GPU".into()
1840        } else {
1841            "CPU".into()
1842        };
1843        let te_dtype = Self::text_encoder_load_dtype(te_use_gpu, dtype);
1844
1845        let preload_text_encoder = self.should_preload_text_encoder();
1846        let te_label = if resolved_text_encoder.is_gguf {
1847            if preload_text_encoder {
1848                format!(
1849                    "Loading Qwen2.5 text encoder ({} GGUF, {})",
1850                    resolved_text_encoder.variant_label, te_device_label
1851                )
1852            } else {
1853                format!(
1854                    "Preparing Qwen2.5 text encoder ({} GGUF, {})",
1855                    resolved_text_encoder.variant_label, te_device_label
1856                )
1857            }
1858        } else if preload_text_encoder {
1859            format!(
1860                "Loading Qwen2.5 text encoder ({} shards, {})",
1861                resolved_text_encoder.paths.len(),
1862                te_device_label,
1863            )
1864        } else {
1865            format!(
1866                "Preparing Qwen2.5 text encoder ({} shards, {})",
1867                resolved_text_encoder.paths.len(),
1868                te_device_label,
1869            )
1870        };
1871        self.base.progress.stage_start(&te_label);
1872        let te_start = Instant::now();
1873        let text_tokenizer = self.load_text_tokenizer(&text_tokenizer_path)?;
1874        let text_encoder = self.load_text_encoder(
1875            &resolved_text_encoder,
1876            &text_tokenizer_path,
1877            text_tokenizer,
1878            &te_device,
1879            te_dtype,
1880            preload_text_encoder,
1881        )?;
1882        self.base.progress.stage_done(&te_label, te_start.elapsed());
1883        if preload_text_encoder {
1884            tracing::info!(device = %te_device_label, "Qwen2.5 text encoder loaded");
1885        } else {
1886            tracing::info!(device = %te_device_label, "Qwen2.5 text encoder prepared for staged loading");
1887        }
1888
1889        self.base.loaded = Some(LoadedQwenImage {
1890            transformer: Some(transformer),
1891            text_encoder,
1892            vae,
1893            vae_path: self.base.paths.vae.clone(),
1894            transformer_cfg,
1895            device,
1896            vae_device,
1897            dtype,
1898        });
1899
1900        tracing::info!(model = %self.base.model_name, "all Qwen-Image components loaded");
1901        Ok(())
1902    }
1903
1904    /// Reload the transformer from disk.
1905    fn reload_transformer(
1906        &self,
1907        loaded: &mut LoadedQwenImage,
1908        width: usize,
1909        height: usize,
1910    ) -> Result<()> {
1911        let transformer = self.load_transformer(
1912            &loaded.device,
1913            loaded.dtype,
1914            &loaded.transformer_cfg,
1915            width,
1916            height,
1917        )?;
1918        loaded.transformer = Some(transformer);
1919        Ok(())
1920    }
1921
1922    /// Generate using sequential loading strategy (load-use-drop each component).
1923    fn generate_sequential(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1924        let text_tokenizer_path = self.validate_paths()?;
1925        let transformer_cfg = self.transformer_config();
1926
1927        let transformer_ref = effective_device_ref(
1928            self.pending_placement.as_ref(),
1929            |adv| Some(adv.transformer),
1930            false,
1931        );
1932        let device = crate::device::resolve_device(Some(transformer_ref), || {
1933            crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)
1934        })?;
1935        let dtype = crate::engine::gpu_dtype(&device);
1936        let transformer_is_quantized = self.detect_is_quantized();
1937
1938        let start = Instant::now();
1939        let seed = req.seed.unwrap_or_else(rand_seed);
1940
1941        let width = req.width as usize;
1942        let height = req.height as usize;
1943        // Reserve-adjusted reading: text-encoder source / placement is a
1944        // budget decision.
1945        let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
1946        let resolved_text_encoder =
1947            self.resolve_text_encoder_source(&device, free, Qwen2TextEncoderUsage::Sequential)?;
1948        let (plan, _device_label) =
1949            self.resolve_text_encoder_plan(&device, &resolved_text_encoder, free);
1950        let use_cpu_staging = plan.use_cpu_staging;
1951
1952        tracing::info!(
1953            prompt = %req.prompt,
1954            seed, width, height,
1955            steps = req.steps,
1956            "starting sequential Qwen-Image generation"
1957        );
1958
1959        self.base
1960            .progress
1961            .info("Using sequential loading (load-use-drop) to minimize peak memory");
1962
1963        // --- Phase 1: Text encoding (check cache first to skip encoder load) ---
1964        let use_cfg = req.guidance > 1.0;
1965        let prompt_key = prompt_text_key(&req.prompt);
1966        let uncond_key = prompt_text_key(QWEN_EMPTY_NEGATIVE_PROMPT);
1967        let (prompt_cached, uncond_cached) = {
1968            let mut cache = self.prompt_cache.lock().expect("cache poisoned");
1969            let prompt_cached = cache.get_cloned(&prompt_key);
1970            let uncond_cached = if use_cfg {
1971                cache.get_cloned(&uncond_key)
1972            } else {
1973                None
1974            };
1975            (prompt_cached, uncond_cached)
1976        };
1977        let both_cached = prompt_cached.is_some() && (!use_cfg || uncond_cached.is_some());
1978
1979        let (mut encoder_hidden_states, mut encoder_attention_mask, mut uncond_hs, mut uncond_mask) =
1980            if both_cached {
1981                self.base.progress.cache_hit("prompt conditioning");
1982                let cached = prompt_cached.unwrap();
1983                let restore_device = if use_cpu_staging {
1984                    &Device::Cpu
1985                } else {
1986                    &device
1987                };
1988                let restore_dtype = if use_cpu_staging { DType::F32 } else { dtype };
1989                let (hs, mask) = cached.restore(restore_device, restore_dtype)?;
1990                let (u_hs, u_mask) = if use_cfg {
1991                    let ucached = uncond_cached.unwrap();
1992                    let (u_hs, u_mask) = ucached.restore(restore_device, restore_dtype)?;
1993                    (Some(u_hs), Some(u_mask))
1994                } else {
1995                    (None, None)
1996                };
1997                (hs, mask, u_hs, u_mask)
1998            } else {
1999                let (te_plan, te_auto_device_label) =
2000                    self.resolve_text_encoder_plan(&device, &resolved_text_encoder, free);
2001                let qwen_ref =
2002                    effective_device_ref(self.pending_placement.as_ref(), |adv| adv.qwen, true);
2003                let auto_te_device = if te_plan.use_gpu {
2004                    device.clone()
2005                } else {
2006                    Device::Cpu
2007                };
2008                let te_device =
2009                    crate::device::resolve_device(Some(qwen_ref), || Ok(auto_te_device.clone()))?;
2010                let te_use_gpu = !te_device.is_cpu();
2011                let te_device_label: String = if te_use_gpu == te_plan.use_gpu {
2012                    te_auto_device_label
2013                } else if te_use_gpu {
2014                    "GPU".into()
2015                } else {
2016                    "CPU".into()
2017                };
2018                let te_dtype = Self::text_encoder_load_dtype(te_use_gpu, dtype);
2019
2020                let te_label = if resolved_text_encoder.is_gguf {
2021                    format!(
2022                        "Loading Qwen2.5 text encoder ({} GGUF, {})",
2023                        resolved_text_encoder.variant_label, te_device_label
2024                    )
2025                } else {
2026                    format!(
2027                        "Loading Qwen2.5 text encoder ({} shards, {})",
2028                        resolved_text_encoder.paths.len(),
2029                        te_device_label,
2030                    )
2031                };
2032                if te_plan.use_cpu_staging && device.is_metal() && !resolved_text_encoder.is_gguf {
2033                    self.base.progress.info(
2034                        "Skipping hard preflight for Qwen2.5 text encoder on Metal; sequential mode spills prompt conditioning to CPU after encoding",
2035                    );
2036                } else {
2037                    let te_activation_budget = crate::device::activation_bytes(
2038                        req.width,
2039                        req.height,
2040                        1,
2041                        crate::device::dtype_bytes(te_dtype),
2042                        crate::device::ActivationFamily::SmallTransformer,
2043                    );
2044                    preflight_memory_check(
2045                        "Qwen2.5 text encoder",
2046                        resolved_text_encoder.size_bytes,
2047                        te_activation_budget,
2048                    )?;
2049                }
2050
2051                if let Some(status) = memory_status_string() {
2052                    self.base.progress.info(&status);
2053                }
2054
2055                self.base.progress.stage_start(&te_label);
2056                let te_start = Instant::now();
2057                let text_tokenizer = self.load_text_tokenizer(&text_tokenizer_path)?;
2058                let mut text_encoder = self.load_text_encoder(
2059                    &resolved_text_encoder,
2060                    &text_tokenizer_path,
2061                    text_tokenizer,
2062                    &te_device,
2063                    te_dtype,
2064                    true,
2065                )?;
2066                self.base.progress.stage_done(&te_label, te_start.elapsed());
2067
2068                let (hs, mask) = Self::encode_prompt_cached(
2069                    &self.base.progress,
2070                    &self.prompt_cache,
2071                    &mut text_encoder,
2072                    &req.prompt,
2073                    &device,
2074                    dtype,
2075                )?;
2076                let (hs, mask) = Self::maybe_spill_conditioning(use_cpu_staging, hs, mask)?;
2077
2078                let (u_hs, u_mask) = if use_cfg {
2079                    let (hs, mask) = Self::encode_prompt_cached(
2080                        &self.base.progress,
2081                        &self.prompt_cache,
2082                        &mut text_encoder,
2083                        QWEN_EMPTY_NEGATIVE_PROMPT,
2084                        &device,
2085                        dtype,
2086                    )?;
2087                    let (hs, mask) = Self::maybe_spill_conditioning(use_cpu_staging, hs, mask)?;
2088                    (Some(hs), Some(mask))
2089                } else {
2090                    (None, None)
2091                };
2092
2093                drop(text_encoder);
2094                // Force the backend to release allocator state before transformer load.
2095                device.synchronize()?;
2096                if let Some(status) = crate::device::memory_status_string() {
2097                    if use_cpu_staging {
2098                        self.base.progress.info(&format!(
2099                            "Freed Qwen2.5 text encoder and spilled prompt conditioning to CPU — {status}"
2100                        ));
2101                    } else {
2102                        self.base
2103                            .progress
2104                            .info(&format!("Freed Qwen2.5 text encoder — {status}"));
2105                    }
2106                } else {
2107                    if use_cpu_staging {
2108                        self.base.progress.info(
2109                            "Freed Qwen2.5 text encoder and spilled prompt conditioning to CPU",
2110                        );
2111                    } else {
2112                        self.base.progress.info("Freed Qwen2.5 text encoder");
2113                    }
2114                }
2115
2116                (hs, mask, u_hs, u_mask)
2117            };
2118
2119        if use_cfg {
2120            let ((cond_hs, cond_mask), (neg_hs, neg_mask)) = align_cfg_conditioning(
2121                &encoder_hidden_states,
2122                &encoder_attention_mask,
2123                uncond_hs.as_ref().expect("unconditional prompt missing"),
2124                uncond_mask.as_ref().expect("unconditional mask missing"),
2125            )?;
2126            encoder_hidden_states = cond_hs;
2127            encoder_attention_mask = cond_mask;
2128            uncond_hs = Some(neg_hs);
2129            uncond_mask = Some(neg_mask);
2130        }
2131
2132        // --- Phase 2: Load transformer and denoise ---
2133        let xformer_paths = self.transformer_paths();
2134        let xformer_size: u64 = xformer_paths
2135            .iter()
2136            .filter_map(|p| std::fs::metadata(p).ok())
2137            .map(|m| m.len())
2138            .sum();
2139        let xformer_activation_budget = crate::device::activation_bytes(
2140            req.width,
2141            req.height,
2142            if req.guidance > 1.0 { 2 } else { 1 },
2143            crate::device::dtype_bytes(dtype),
2144            crate::device::ActivationFamily::QwenImageDit,
2145        );
2146        preflight_memory_check(
2147            "Qwen-Image transformer",
2148            xformer_size,
2149            xformer_activation_budget,
2150        )?;
2151
2152        if let Some(status) = memory_status_string() {
2153            self.base.progress.info(&status);
2154        }
2155
2156        let xformer_label = if transformer_is_quantized {
2157            "Loading Qwen-Image transformer (quantized)".to_string()
2158        } else {
2159            format!(
2160                "Loading Qwen-Image transformer ({} shards)",
2161                xformer_paths.len()
2162            )
2163        };
2164        self.base.progress.stage_start(&xformer_label);
2165        let xformer_start = Instant::now();
2166        let transformer = self.load_transformer(&device, dtype, &transformer_cfg, width, height)?;
2167        self.base
2168            .progress
2169            .stage_done(&xformer_label, xformer_start.elapsed());
2170
2171        if use_cpu_staging {
2172            encoder_hidden_states = encoder_hidden_states.to_device(&device)?.to_dtype(dtype)?;
2173            encoder_attention_mask = encoder_attention_mask.to_device(&device)?;
2174            if let Some(hs) = uncond_hs.take() {
2175                uncond_hs = Some(hs.to_device(&device)?.to_dtype(dtype)?);
2176            }
2177            if let Some(mask) = uncond_mask.take() {
2178                uncond_mask = Some(mask.to_device(&device)?);
2179            }
2180            if let Some(status) = memory_status_string() {
2181                self.base.progress.info(&format!(
2182                    "Restored prompt conditioning to GPU for denoising — {status}"
2183                ));
2184            } else {
2185                self.base
2186                    .progress
2187                    .info("Restored prompt conditioning to GPU for denoising");
2188            }
2189        }
2190
2191        // Calculate latent dimensions: image_size / 8 (VAE downsample factor)
2192        let vae_downsample = 8;
2193        let latent_h = height / vae_downsample;
2194        let latent_w = width / vae_downsample;
2195        let is_img2img = req.source_image.is_some();
2196
2197        // For img2img, load VAE early to encode source image before transformer
2198        let (prepared_img2img_latents, inpaint_ctx) = if let Some(ref source_bytes) =
2199            req.source_image
2200        {
2201            // Reserve-adjusted reading drives the encode-device decision.
2202            let free_for_encode = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
2203            let encode_on_gpu = should_use_gpu(
2204                device.is_cuda(),
2205                device.is_metal(),
2206                free_for_encode,
2207                VAE_DECODE_VRAM_THRESHOLD,
2208            );
2209            let encode_device = if encode_on_gpu {
2210                device.clone()
2211            } else {
2212                Device::Cpu
2213            };
2214            let encode_label = if encode_on_gpu { "GPU" } else { "CPU" };
2215
2216            let vae_label = format!("Loading Qwen-Image VAE ({}, F32) for encode", encode_label);
2217            self.base.progress.stage_start(&vae_label);
2218            let vae_start = Instant::now();
2219            let encode_vae = self.load_vae(&encode_device, DType::F32)?;
2220            self.base
2221                .progress
2222                .stage_done(&vae_label, vae_start.elapsed());
2223
2224            let encoded = Self::encode_vae_with_fallback(
2225                source_bytes,
2226                req.width,
2227                req.height,
2228                &encode_vae,
2229                &encode_device,
2230                &device,
2231                &self.base.progress,
2232                || self.load_vae(&Device::Cpu, DType::F32),
2233            )?;
2234            let encoded = encoded.to_device(&device)?.to_dtype(dtype)?;
2235            let start_sigma = QwenImageScheduler::new_img2img(
2236                req.steps as usize,
2237                image_seq_len(latent_h, latent_w, transformer_cfg.patch_size),
2238                req.strength,
2239            )
2240            .0
2241            .initial_sigma();
2242            let prepared = crate::img2img::prepare_flow_match_img2img(
2243                &encoded,
2244                seed,
2245                &[1, 16, latent_h, latent_w],
2246                start_sigma,
2247                req.mask_image.as_deref(),
2248                latent_h,
2249                latent_w,
2250                &device,
2251                dtype,
2252            )?;
2253
2254            // Drop early VAE to free memory before transformer load
2255            drop(encode_vae);
2256            device.synchronize()?;
2257
2258            tracing::info!(
2259                strength = req.strength,
2260                "img2img: encoded source image to latents"
2261            );
2262
2263            (Some(prepared.initial_latents), prepared.inpaint_ctx)
2264        } else {
2265            (None, None)
2266        };
2267
2268        let image_seq_len = image_seq_len(latent_h, latent_w, transformer_cfg.patch_size);
2269        let (mut scheduler, num_steps) = if is_img2img {
2270            QwenImageScheduler::new_img2img(req.steps as usize, image_seq_len, req.strength)
2271        } else {
2272            let sched = QwenImageScheduler::new(req.steps as usize, image_seq_len);
2273            let n = sched.num_steps();
2274            (sched, n)
2275        };
2276
2277        // Build initial latents
2278        let mut latents = if let Some(initial) = &prepared_img2img_latents {
2279            initial.clone()
2280        } else {
2281            let noise =
2282                crate::engine::seeded_randn(seed, &[1, 16, latent_h, latent_w], &device, dtype)?;
2283            (noise * scheduler.initial_sigma())?
2284        };
2285
2286        let denoise_label = format!("Denoising ({} steps)", num_steps);
2287        self.base.progress.stage_start(&denoise_label);
2288        let denoise_start = Instant::now();
2289
2290        if std::env::var_os("MOLD_QWEN_DEBUG").is_some() {
2291            eprintln!(
2292                "[qwen-debug] cfg={} guidance={:.1} image_seq_len={} sigmas[0]={:.4} sigmas[last]={:.4} img2img={}",
2293                use_cfg,
2294                req.guidance,
2295                image_seq_len,
2296                scheduler.sigmas[0],
2297                scheduler.sigmas[scheduler.sigmas.len() - 1],
2298                is_img2img,
2299            );
2300        }
2301
2302        let use_batched_cfg = use_cfg && transformer.supports_cfg_batching();
2303        if use_cfg && !use_batched_cfg {
2304            self.base.progress.info(
2305                "Low-memory quantized Qwen CUDA path detected — disabling CFG batching to reduce peak CUDA memory",
2306            );
2307        }
2308
2309        // Pre-batch CFG inputs when the selected transformer path can handle the
2310        // extra batch dimension without exceeding peak memory.
2311        let (batched_hs, batched_mask) = if use_batched_cfg {
2312            let hs = Tensor::cat(&[&encoder_hidden_states, uncond_hs.as_ref().unwrap()], 0)?;
2313            let mask = Tensor::cat(&[&encoder_attention_mask, uncond_mask.as_ref().unwrap()], 0)?;
2314            (hs, mask)
2315        } else {
2316            (
2317                encoder_hidden_states.clone(),
2318                encoder_attention_mask.clone(),
2319            )
2320        };
2321
2322        for step in 0..num_steps {
2323            let step_start = Instant::now();
2324            let t = scheduler.current_timestep();
2325            let noise_pred = if use_cfg {
2326                let (cond_pred, uncond_pred) = if use_batched_cfg {
2327                    let t_tensor =
2328                        Tensor::from_vec(vec![t as f32; 2], (2,), &device)?.to_dtype(dtype)?;
2329                    let batched_latents = Tensor::cat(&[&latents, &latents], 0)?;
2330                    let batched_pred = transformer.forward(
2331                        &batched_latents,
2332                        &t_tensor,
2333                        &batched_hs,
2334                        &batched_mask,
2335                    )?;
2336                    (batched_pred.narrow(0, 0, 1)?, batched_pred.narrow(0, 1, 1)?)
2337                } else {
2338                    let t_tensor =
2339                        Tensor::from_vec(vec![t as f32], (1,), &device)?.to_dtype(dtype)?;
2340                    (
2341                        transformer.forward(
2342                            &latents,
2343                            &t_tensor,
2344                            &encoder_hidden_states,
2345                            &encoder_attention_mask,
2346                        )?,
2347                        transformer.forward(
2348                            &latents,
2349                            &t_tensor,
2350                            uncond_hs.as_ref().unwrap(),
2351                            uncond_mask.as_ref().unwrap(),
2352                        )?,
2353                    )
2354                };
2355                if step == 0 {
2356                    Self::debug_tensor_stats("cond_pred[0]", &cond_pred);
2357                    Self::debug_tensor_stats("uncond_pred[0]", &uncond_pred);
2358                }
2359                // CFG in F32 to avoid BF16 cancellation error, then norm rescale
2360                // to match diffusers' Qwen-Image pipeline.
2361                let cond_f32 = cond_pred.to_dtype(DType::F32)?;
2362                let uncond_f32 = uncond_pred.to_dtype(DType::F32)?;
2363                let comb = (&uncond_f32 + ((&cond_f32 - &uncond_f32)? * req.guidance)?)?;
2364                let cond_norm = cond_f32.sqr()?.sum_keepdim(1)?.sqrt()?;
2365                let comb_norm = comb.sqr()?.sum_keepdim(1)?.sqrt()?.clamp(1e-8, f64::MAX)?;
2366                let rescaled = comb.broadcast_mul(&(cond_norm / comb_norm)?)?;
2367                rescaled.to_dtype(dtype)?
2368            } else {
2369                let t_tensor = Tensor::from_vec(vec![t as f32], (1,), &device)?.to_dtype(dtype)?;
2370                transformer.forward(
2371                    &latents,
2372                    &t_tensor,
2373                    &encoder_hidden_states,
2374                    &encoder_attention_mask,
2375                )?
2376            };
2377            if step == 0 || step == num_steps / 2 || step == num_steps - 1 {
2378                Self::debug_tensor_stats(&format!("noise_pred[{step}]"), &noise_pred);
2379                Self::debug_tensor_stats(&format!("latents[{step}]"), &latents);
2380            }
2381            if step == 0 {
2382                Self::validate_qwen_tensor_boundary("noise_pred[0]", &noise_pred)?;
2383            }
2384            latents = scheduler.step(&noise_pred, &latents)?;
2385            if step == num_steps - 1 {
2386                Self::validate_qwen_tensor_boundary("latents_final", &latents)?;
2387            }
2388
2389            // Inpainting: blend preserved regions back at current noise level
2390            if let Some(ref ctx) = inpaint_ctx {
2391                latents = crate::img2img::apply_flow_match_inpaint(
2392                    &latents,
2393                    ctx,
2394                    scheduler.sigmas[step + 1],
2395                )?;
2396            }
2397
2398            if std::env::var_os("MOLD_QWEN_DEBUG").is_some() {
2399                let n = latents
2400                    .ne(&latents)?
2401                    .to_dtype(candle_core::DType::U32)?
2402                    .sum_all()?
2403                    .to_scalar::<u32>()?;
2404                if n > 0 {
2405                    eprintln!(
2406                        "[qwen-nan] NaN in latents AFTER step {step}: {n}/{}",
2407                        latents.elem_count()
2408                    );
2409                }
2410            }
2411            self.base.progress.emit(ProgressEvent::DenoiseStep {
2412                step: step + 1,
2413                total: num_steps,
2414                elapsed: step_start.elapsed(),
2415            });
2416        }
2417
2418        self.base
2419            .progress
2420            .stage_done(&denoise_label, denoise_start.elapsed());
2421
2422        // Drop transformer and embeddings
2423        drop(transformer);
2424        drop(encoder_hidden_states);
2425        drop(encoder_attention_mask);
2426        drop(uncond_hs);
2427        drop(uncond_mask);
2428        device.synchronize()?;
2429        self.base.progress.info("Freed Qwen-Image transformer");
2430
2431        // --- Phase 3: Load VAE and decode ---
2432        if let Some(status) = memory_status_string() {
2433            self.base.progress.info(&status);
2434        }
2435
2436        // Reserve-adjusted reading: VAE placement is a budget decision.
2437        let free_for_vae = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
2438        let vae_on_gpu = should_use_gpu(
2439            device.is_cuda(),
2440            device.is_metal(),
2441            free_for_vae,
2442            VAE_DECODE_VRAM_THRESHOLD,
2443        );
2444        let vae_ref =
2445            effective_device_ref(self.pending_placement.as_ref(), |adv| Some(adv.vae), false);
2446        let vae_device = crate::device::resolve_device(Some(vae_ref), || {
2447            Ok(if vae_on_gpu {
2448                device.clone()
2449            } else {
2450                Device::Cpu
2451            })
2452        })?;
2453        let vae_on_gpu = !vae_device.is_cpu();
2454        // Always decode in F32 — BF16 convolutions accumulate quantization noise across
2455        // the 4 upsampling blocks, producing visible grain. Matches diffusers' force_upcast.
2456        let vae_dtype = DType::F32;
2457        let vae_device_label = if vae_on_gpu { "GPU" } else { "CPU" };
2458
2459        let vae_label = format!("Loading Qwen-Image VAE ({}, F32)", vae_device_label);
2460        self.base.progress.stage_start(&vae_label);
2461        let vae_start = Instant::now();
2462        let vae = self.load_vae(&vae_device, vae_dtype)?;
2463        self.base
2464            .progress
2465            .stage_done(&vae_label, vae_start.elapsed());
2466
2467        self.base.progress.stage_start("VAE decode");
2468        let vae_decode_start = Instant::now();
2469        let free_for_decode = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
2470        let prefer_tiled = Self::should_proactively_tile_vae_decode(
2471            req.width,
2472            req.height,
2473            vae_device.is_cuda(),
2474            free_for_decode,
2475        );
2476
2477        let image = Self::decode_vae_with_fallback(
2478            &latents,
2479            &vae,
2480            &vae_device,
2481            &device,
2482            &self.base.progress,
2483            prefer_tiled,
2484            || self.load_vae(&Device::Cpu, DType::F32),
2485        )?;
2486        Self::validate_qwen_tensor_boundary("image_pre_postprocess", &image)?;
2487        Self::debug_tensor_stats("image_pre_postprocess", &image);
2488        let image = postprocess_image(&image)?;
2489        let post_stats = Self::validate_qwen_tensor_boundary("image_postprocess", &image)?;
2490        Self::debug_tensor_stats("image_postprocess", &image);
2491        let image = image.i(0)?;
2492        if Self::near_black_image_stats(post_stats) {
2493            self.base.progress.info(
2494                "Qwen diagnostic: decoded image is near-black after VAE postprocess; inspect MOLD_QWEN_DEBUG tensor stats to separate denoise math from VAE decode",
2495            );
2496            tracing::warn!(
2497                min = post_stats.min,
2498                max = post_stats.max,
2499                mean = post_stats.mean,
2500                "Qwen decoded image is near-black after VAE postprocess"
2501            );
2502        }
2503
2504        self.base
2505            .progress
2506            .stage_done("VAE decode", vae_decode_start.elapsed());
2507
2508        let output_metadata = build_output_metadata(req, seed, None);
2509        let image_bytes = encode_image(
2510            &image,
2511            req.resolved_output_format(),
2512            req.width,
2513            req.height,
2514            output_metadata.as_ref(),
2515        )?;
2516
2517        let generation_time_ms = start.elapsed().as_millis() as u64;
2518        tracing::info!(
2519            generation_time_ms,
2520            seed,
2521            "sequential Qwen-Image generation complete"
2522        );
2523
2524        Ok(GenerateResponse {
2525            images: vec![ImageData {
2526                data: image_bytes,
2527                format: req.resolved_output_format(),
2528                width: req.width,
2529                height: req.height,
2530                index: 0,
2531            }],
2532            generation_time_ms,
2533            model: req.model.clone(),
2534            seed_used: seed,
2535            video: None,
2536            gpu: None,
2537        })
2538    }
2539
2540    fn generate_edit_loaded(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
2541        let progress = &self.base.progress;
2542        let start = Instant::now();
2543
2544        let loaded_ref = self
2545            .base
2546            .loaded
2547            .as_ref()
2548            .ok_or_else(|| anyhow::anyhow!("model not loaded"))?;
2549        let needs_reload = loaded_ref.transformer.is_none();
2550        if needs_reload {
2551            let mut loaded_mut = self
2552                .base
2553                .loaded
2554                .take()
2555                .ok_or_else(|| anyhow::anyhow!("model not loaded"))?;
2556            progress.stage_start("Reloading Qwen-Image transformer");
2557            let reload_start = Instant::now();
2558            self.reload_transformer(&mut loaded_mut, req.width as usize, req.height as usize)?;
2559            progress.stage_done("Reloading Qwen-Image transformer", reload_start.elapsed());
2560            self.base.loaded = Some(loaded_mut);
2561        }
2562
2563        let is_edit_family = self.is_edit_family();
2564        let loaded = self
2565            .base
2566            .loaded
2567            .as_mut()
2568            .ok_or_else(|| anyhow::anyhow!("model not loaded"))?;
2569        let seed = req.seed.unwrap_or_else(rand_seed);
2570        let width = req.width as usize;
2571        let height = req.height as usize;
2572        let edit_images = req
2573            .edit_images
2574            .as_ref()
2575            .ok_or_else(|| anyhow::anyhow!("qwen-image-edit requires edit_images"))?;
2576        let use_cfg = req.guidance > 1.0;
2577        let negative_prompt = req
2578            .negative_prompt
2579            .as_deref()
2580            .unwrap_or(QWEN_EMPTY_NEGATIVE_PROMPT);
2581        let formatted_prompt = Self::qwen_image_edit_prompt(&req.prompt, edit_images.len());
2582        let formatted_negative = Self::qwen_image_edit_prompt(negative_prompt, edit_images.len());
2583
2584        tracing::info!(
2585            prompt = %req.prompt,
2586            seed,
2587            width,
2588            height,
2589            steps = req.steps,
2590            edit_images = edit_images.len(),
2591            "starting Qwen-Image edit generation"
2592        );
2593
2594        if loaded.text_encoder.model.is_none() {
2595            let label = if loaded.text_encoder.is_parked() {
2596                "Unparking Qwen2.5 encoder (CPU→GPU)"
2597            } else {
2598                "Reloading Qwen2.5 encoder"
2599            };
2600            progress.stage_start(label);
2601            let reload_start = Instant::now();
2602            if loaded.text_encoder.is_parked() {
2603                loaded.text_encoder.unpark_to_gpu(progress)?;
2604            } else {
2605                loaded.text_encoder.reload(progress)?;
2606            }
2607            progress.stage_done(label, reload_start.elapsed());
2608        }
2609
2610        progress.stage_start("Encoding prompt (Qwen2.5 edit)");
2611        let encode_start = Instant::now();
2612        let (encoder_hidden_states, encoder_attention_mask, _) =
2613            loaded.text_encoder.encode_formatted_multimodal(
2614                &formatted_prompt,
2615                edit_images,
2616                &loaded.device,
2617                loaded.dtype,
2618            )?;
2619        progress.stage_done("Encoding prompt (Qwen2.5 edit)", encode_start.elapsed());
2620        let (encoder_hidden_states, encoder_attention_mask, uncond_hs, uncond_mask) = if use_cfg {
2621            progress.stage_start("Encoding negative prompt (Qwen2.5 edit)");
2622            let neg_start = Instant::now();
2623            let (hs, mask, _) = loaded.text_encoder.encode_formatted_multimodal(
2624                &formatted_negative,
2625                edit_images,
2626                &loaded.device,
2627                loaded.dtype,
2628            )?;
2629            progress.stage_done(
2630                "Encoding negative prompt (Qwen2.5 edit)",
2631                neg_start.elapsed(),
2632            );
2633            let ((cond_hs, cond_mask), (neg_hs, neg_mask)) = align_cfg_conditioning(
2634                &encoder_hidden_states,
2635                &encoder_attention_mask,
2636                &hs,
2637                &mask,
2638            )?;
2639            (cond_hs, cond_mask, Some(neg_hs), Some(neg_mask))
2640        } else {
2641            (encoder_hidden_states, encoder_attention_mask, None, None)
2642        };
2643
2644        let drop_text_encoder = is_edit_family || loaded.text_encoder.on_gpu;
2645        if drop_text_encoder {
2646            let park_mode = crate::device::keep_te_in_ram()
2647                && !loaded.device.is_metal()
2648                && !loaded.text_encoder.is_quantized;
2649            if park_mode {
2650                loaded.text_encoder.park_to_cpu()?;
2651                tracing::info!(
2652                    on_gpu = loaded.text_encoder.on_gpu,
2653                    "Qwen2.5 text encoder parked to CPU host RAM after edit conditioning"
2654                );
2655            } else {
2656                loaded.text_encoder.drop_weights();
2657                tracing::info!(
2658                    on_gpu = loaded.text_encoder.on_gpu,
2659                    "Qwen2.5 text encoder dropped after edit conditioning"
2660                );
2661            }
2662        }
2663
2664        let mut packed_input_storage = Vec::with_capacity(edit_images.len());
2665        let mut img_shapes = vec![(1usize, height / 16, width / 16)];
2666        progress.stage_start("Encoding edit images (VAE)");
2667        let encode_start = Instant::now();
2668        for image_bytes in edit_images {
2669            let (vae_width, vae_height) =
2670                Self::qwen_image_edit_image_dims(image_bytes, QWEN_IMAGE_EDIT_VAE_AREA)?;
2671            let encoded = Self::encode_vae_with_fallback(
2672                image_bytes,
2673                vae_width,
2674                vae_height,
2675                &loaded.vae,
2676                &loaded.vae_device,
2677                &loaded.device,
2678                progress,
2679                || {
2680                    Ok(QwenImageVae::load(
2681                        &loaded.vae_path,
2682                        &Device::Cpu,
2683                        DType::F32,
2684                        progress,
2685                    )?)
2686                },
2687            )?
2688            .to_device(&loaded.device)?
2689            .to_dtype(loaded.dtype)?;
2690            img_shapes.push((1, encoded.dim(2)? / 2, encoded.dim(3)? / 2));
2691            packed_input_storage.push(Self::pack_latents_4d(&encoded)?);
2692        }
2693        progress.stage_done("Encoding edit images (VAE)", encode_start.elapsed());
2694
2695        let packed_inputs = if packed_input_storage.is_empty() {
2696            None
2697        } else {
2698            let tensors = packed_input_storage.iter().collect::<Vec<_>>();
2699            Some(Tensor::cat(&tensors, 1)?)
2700        };
2701
2702        let noise = crate::engine::seeded_randn(
2703            seed,
2704            &[1, 16, height / 8, width / 8],
2705            &loaded.device,
2706            loaded.dtype,
2707        )?;
2708        let mut scheduler =
2709            QwenImageScheduler::new(req.steps as usize, (height / 16) * (width / 16));
2710        let num_steps = scheduler.num_steps();
2711        let mut latents = Self::pack_latents_4d(&(noise * scheduler.initial_sigma())?)?;
2712        let output_seq_len = latents.dim(1)?;
2713
2714        let denoise_label = format!("Denoising edit ({} steps)", num_steps);
2715        progress.stage_start(&denoise_label);
2716        let denoise_start = Instant::now();
2717
2718        {
2719            let transformer = loaded
2720                .transformer
2721                .as_ref()
2722                .expect("transformer must be loaded for denoising");
2723            let use_batched_cfg = use_cfg && transformer.supports_cfg_batching();
2724            let (batched_hs, batched_mask) = if use_batched_cfg {
2725                let hs = Tensor::cat(&[&encoder_hidden_states, uncond_hs.as_ref().unwrap()], 0)?;
2726                let mask =
2727                    Tensor::cat(&[&encoder_attention_mask, uncond_mask.as_ref().unwrap()], 0)?;
2728                (hs, mask)
2729            } else {
2730                (
2731                    encoder_hidden_states.clone(),
2732                    encoder_attention_mask.clone(),
2733                )
2734            };
2735
2736            for step in 0..num_steps {
2737                let step_start = Instant::now();
2738                let t = scheduler.current_timestep();
2739                let timestep = if use_batched_cfg {
2740                    Tensor::from_vec(vec![t as f32; 2], (2,), &loaded.device)?
2741                        .to_dtype(loaded.dtype)?
2742                } else {
2743                    Tensor::from_vec(vec![t as f32], (1,), &loaded.device)?
2744                        .to_dtype(loaded.dtype)?
2745                };
2746
2747                let latent_model_input = if let Some(ref packed_inputs) = packed_inputs {
2748                    Tensor::cat(&[&latents, packed_inputs], 1)?
2749                } else {
2750                    latents.clone()
2751                };
2752
2753                let noise_pred = if use_cfg {
2754                    let (cond_pred, uncond_pred) = if use_batched_cfg {
2755                        let batched_input =
2756                            Tensor::cat(&[&latent_model_input, &latent_model_input], 0)?;
2757                        let pred = transformer.forward_packed(
2758                            &batched_input,
2759                            &timestep,
2760                            &batched_hs,
2761                            &batched_mask,
2762                            &img_shapes,
2763                        )?;
2764                        (
2765                            pred.narrow(0, 0, 1)?.narrow(1, 0, output_seq_len)?,
2766                            pred.narrow(0, 1, 1)?.narrow(1, 0, output_seq_len)?,
2767                        )
2768                    } else {
2769                        (
2770                            transformer
2771                                .forward_packed(
2772                                    &latent_model_input,
2773                                    &timestep,
2774                                    &encoder_hidden_states,
2775                                    &encoder_attention_mask,
2776                                    &img_shapes,
2777                                )?
2778                                .narrow(1, 0, output_seq_len)?,
2779                            transformer
2780                                .forward_packed(
2781                                    &latent_model_input,
2782                                    &timestep,
2783                                    uncond_hs.as_ref().unwrap(),
2784                                    uncond_mask.as_ref().unwrap(),
2785                                    &img_shapes,
2786                                )?
2787                                .narrow(1, 0, output_seq_len)?,
2788                        )
2789                    };
2790
2791                    let cond_f32 = cond_pred.to_dtype(DType::F32)?;
2792                    let uncond_f32 = uncond_pred.to_dtype(DType::F32)?;
2793                    let comb = (&uncond_f32 + ((&cond_f32 - &uncond_f32)? * req.guidance)?)?;
2794                    let cond_norm = cond_f32.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?;
2795                    let comb_norm = comb
2796                        .sqr()?
2797                        .sum_keepdim(D::Minus1)?
2798                        .sqrt()?
2799                        .clamp(1e-8, f64::MAX)?;
2800                    comb.broadcast_mul(&(cond_norm / comb_norm)?)?
2801                        .to_dtype(loaded.dtype)?
2802                } else {
2803                    transformer
2804                        .forward_packed(
2805                            &latent_model_input,
2806                            &timestep,
2807                            &encoder_hidden_states,
2808                            &encoder_attention_mask,
2809                            &img_shapes,
2810                        )?
2811                        .narrow(1, 0, output_seq_len)?
2812                };
2813
2814                latents = scheduler.step(&noise_pred, &latents)?;
2815                progress.emit(ProgressEvent::DenoiseStep {
2816                    step: step + 1,
2817                    total: num_steps,
2818                    elapsed: step_start.elapsed(),
2819                });
2820            }
2821        }
2822
2823        progress.stage_done(&denoise_label, denoise_start.elapsed());
2824
2825        let latents = Self::unpack_latents_packed(&latents, height / 8, width / 8)?;
2826        let free_for_decode = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
2827        let prefer_tiled = Self::should_proactively_tile_vae_decode(
2828            req.width,
2829            req.height,
2830            loaded.vae_device.is_cuda(),
2831            free_for_decode,
2832        );
2833        let image = Self::decode_vae_with_fallback(
2834            &latents,
2835            &loaded.vae,
2836            &loaded.vae_device,
2837            &loaded.device,
2838            progress,
2839            prefer_tiled,
2840            || {
2841                Ok(QwenImageVae::load(
2842                    &loaded.vae_path,
2843                    &Device::Cpu,
2844                    DType::F32,
2845                    progress,
2846                )?)
2847            },
2848        )?;
2849        let image = postprocess_image(&image)?.i(0)?;
2850        let output_metadata = build_output_metadata(req, seed, None);
2851        let image_bytes = encode_image(
2852            &image,
2853            req.resolved_output_format(),
2854            req.width,
2855            req.height,
2856            output_metadata.as_ref(),
2857        )?;
2858
2859        Ok(GenerateResponse {
2860            images: vec![ImageData {
2861                data: image_bytes,
2862                format: req.resolved_output_format(),
2863                width: req.width,
2864                height: req.height,
2865                index: 0,
2866            }],
2867            generation_time_ms: start.elapsed().as_millis() as u64,
2868            model: req.model.clone(),
2869            seed_used: seed,
2870            video: None,
2871            gpu: None,
2872        })
2873    }
2874}
2875
2876impl QwenImageEngine {
2877    fn generate_inner(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
2878        if req.scheduler.is_some() {
2879            tracing::warn!(
2880                "scheduler selection not supported for Qwen-Image (flow-matching), ignoring"
2881            );
2882        }
2883
2884        if self.is_edit_family() {
2885            let sequential = self.base.load_strategy == LoadStrategy::Sequential;
2886            if sequential && self.base.loaded.is_none() {
2887                let original = self.base.load_strategy;
2888                self.base.load_strategy = LoadStrategy::Eager;
2889                let load_result = self.load();
2890                self.base.load_strategy = original;
2891                load_result?;
2892            }
2893            if self.base.loaded.is_none() {
2894                bail!("model not loaded -- call load() first");
2895            }
2896            let result = self.generate_edit_loaded(req);
2897            if sequential {
2898                self.unload();
2899            }
2900            return result;
2901        }
2902
2903        // Sequential mode: load-use-drop each component
2904        if self.base.load_strategy == LoadStrategy::Sequential {
2905            return self.generate_sequential(req);
2906        }
2907
2908        // Eager mode: use pre-loaded components
2909        if self.base.loaded.is_none() {
2910            bail!("model not loaded -- call load() first");
2911        }
2912
2913        let progress = &self.base.progress;
2914        let gpu_ordinal = self.base.gpu_ordinal;
2915        let start = Instant::now();
2916
2917        // Reload transformer if it was dropped after previous VAE decode
2918        let loaded_ref = self
2919            .base
2920            .loaded
2921            .as_ref()
2922            .ok_or_else(|| anyhow::anyhow!("model not loaded"))?;
2923        let needs_reload = loaded_ref.transformer.is_none();
2924        if needs_reload {
2925            let mut loaded_mut = self
2926                .base
2927                .loaded
2928                .take()
2929                .ok_or_else(|| anyhow::anyhow!("model not loaded"))?;
2930            progress.stage_start("Reloading Qwen-Image transformer");
2931            let reload_start = Instant::now();
2932            self.reload_transformer(&mut loaded_mut, req.width as usize, req.height as usize)?;
2933            progress.stage_done("Reloading Qwen-Image transformer", reload_start.elapsed());
2934            self.base.loaded = Some(loaded_mut);
2935        }
2936
2937        let loaded = self
2938            .base
2939            .loaded
2940            .as_mut()
2941            .ok_or_else(|| anyhow::anyhow!("model not loaded"))?;
2942        let seed = req.seed.unwrap_or_else(rand_seed);
2943
2944        let width = req.width as usize;
2945        let height = req.height as usize;
2946
2947        tracing::info!(
2948            prompt = %req.prompt,
2949            seed, width, height,
2950            steps = req.steps,
2951            "starting Qwen-Image generation"
2952        );
2953
2954        let use_cfg = req.guidance > 1.0;
2955        let prompt_key = prompt_text_key(&req.prompt);
2956        let uncond_key = prompt_text_key(QWEN_EMPTY_NEGATIVE_PROMPT);
2957        let prompt_cached = self
2958            .prompt_cache
2959            .lock()
2960            .expect("cache poisoned")
2961            .get_cloned(&prompt_key);
2962        let uncond_cached = if use_cfg {
2963            self.prompt_cache
2964                .lock()
2965                .expect("cache poisoned")
2966                .get_cloned(&uncond_key)
2967        } else {
2968            None
2969        };
2970        let both_cached = prompt_cached.is_some() && (!use_cfg || uncond_cached.is_some());
2971
2972        let (encoder_hidden_states, encoder_attention_mask, uncond_hs, uncond_mask) = if both_cached
2973        {
2974            let cached = prompt_cached.expect("prompt cache unexpectedly missing");
2975            progress.cache_hit("prompt conditioning");
2976            let (hs, mask) = cached.restore(&loaded.device, loaded.dtype)?;
2977            let (u_hs, u_mask) = if use_cfg {
2978                progress.cache_hit("unconditional conditioning");
2979                let ucached =
2980                    uncond_cached.expect("unconditional prompt cache unexpectedly missing");
2981                let (u_hs, u_mask) = ucached.restore(&loaded.device, loaded.dtype)?;
2982                (Some(u_hs), Some(u_mask))
2983            } else {
2984                (None, None)
2985            };
2986            (hs, mask, u_hs, u_mask)
2987        } else {
2988            if loaded.text_encoder.model.is_none() {
2989                let label = if loaded.text_encoder.is_parked() {
2990                    "Unparking Qwen2.5 encoder (CPU→GPU)"
2991                } else {
2992                    "Reloading Qwen2.5 encoder"
2993                };
2994                progress.stage_start(label);
2995                let reload_start = Instant::now();
2996                if loaded.text_encoder.is_parked() {
2997                    loaded.text_encoder.unpark_to_gpu(progress)?;
2998                } else {
2999                    loaded.text_encoder.reload(progress)?;
3000                }
3001                progress.stage_done(label, reload_start.elapsed());
3002            }
3003
3004            let (hs, mask) = Self::encode_prompt_cached(
3005                progress,
3006                &self.prompt_cache,
3007                &mut loaded.text_encoder,
3008                &req.prompt,
3009                &loaded.device,
3010                loaded.dtype,
3011            )?;
3012
3013            let (u_hs, u_mask) = if use_cfg {
3014                let (hs, mask) = Self::encode_prompt_cached(
3015                    progress,
3016                    &self.prompt_cache,
3017                    &mut loaded.text_encoder,
3018                    QWEN_EMPTY_NEGATIVE_PROMPT,
3019                    &loaded.device,
3020                    loaded.dtype,
3021                )?;
3022                (Some(hs), Some(mask))
3023            } else {
3024                (None, None)
3025            };
3026
3027            (hs, mask, u_hs, u_mask)
3028        };
3029
3030        let (encoder_hidden_states, encoder_attention_mask, uncond_hs, uncond_mask) = if use_cfg {
3031            let ((cond_hs, cond_mask), (neg_hs, neg_mask)) = align_cfg_conditioning(
3032                &encoder_hidden_states,
3033                &encoder_attention_mask,
3034                uncond_hs.as_ref().expect("unconditional prompt missing"),
3035                uncond_mask.as_ref().expect("unconditional mask missing"),
3036            )?;
3037            (cond_hs, cond_mask, Some(neg_hs), Some(neg_mask))
3038        } else {
3039            (
3040                encoder_hidden_states,
3041                encoder_attention_mask,
3042                uncond_hs,
3043                uncond_mask,
3044            )
3045        };
3046
3047        // Drop or park text encoder to free VRAM for denoising.
3048        if loaded.text_encoder.on_gpu {
3049            let free_after_encode = usable_free_vram_bytes(gpu_ordinal).unwrap_or(0);
3050            let required_for_residency = Self::qwen2_hot_text_encoder_required_vram(
3051                req.width,
3052                req.height,
3053                if req.guidance > 1.0 { 2 } else { 1 },
3054                loaded.dtype,
3055            );
3056            let action =
3057                Self::qwen2_text_encoder_post_encode_action(Qwen2TextEncoderResidencyInput {
3058                    on_gpu: loaded.text_encoder.on_gpu,
3059                    is_quantized: loaded.text_encoder.is_quantized,
3060                    is_metal: loaded.device.is_metal(),
3061                    keep_te_ram: crate::device::keep_te_in_ram(),
3062                    prompt_cache_miss: !both_cached,
3063                    transformer_resident: loaded.transformer.is_some(),
3064                    free_vram_bytes: free_after_encode,
3065                    required_vram_bytes: required_for_residency,
3066                });
3067            match action {
3068                Qwen2TextEncoderPostEncodeAction::KeepGpu => {
3069                    progress.info(&format!(
3070                        "Keeping Qwen2.5 text encoder on GPU for hot prompt-cache misses ({} free >= {} reserve)",
3071                        fmt_gb(free_after_encode),
3072                        fmt_gb(required_for_residency)
3073                    ));
3074                    tracing::info!(
3075                        free_vram_bytes = free_after_encode,
3076                        required_vram_bytes = required_for_residency,
3077                        is_quantized = loaded.text_encoder.is_quantized,
3078                        "Qwen2.5 text encoder kept on GPU after cache miss"
3079                    );
3080                }
3081                Qwen2TextEncoderPostEncodeAction::ParkCpu => {
3082                    loaded.text_encoder.park_to_cpu()?;
3083                    progress.info(&format!(
3084                        "Parked Qwen2.5 text encoder to CPU host RAM before denoise ({} free < {} reserve)",
3085                        fmt_gb(free_after_encode),
3086                        fmt_gb(required_for_residency)
3087                    ));
3088                    tracing::info!("Qwen2.5 text encoder parked to CPU host RAM");
3089                }
3090                Qwen2TextEncoderPostEncodeAction::Drop => {
3091                    loaded.text_encoder.drop_weights();
3092                    progress.info(&format!(
3093                        "Dropped Qwen2.5 text encoder before denoise ({} free < {} reserve or cache hit)",
3094                        fmt_gb(free_after_encode),
3095                        fmt_gb(required_for_residency)
3096                    ));
3097                    tracing::info!("Qwen2.5 text encoder dropped from GPU");
3098                }
3099            }
3100        }
3101
3102        // 3. Calculate latent dimensions
3103        let vae_downsample = 8;
3104        let latent_h = height / vae_downsample;
3105        let latent_w = width / vae_downsample;
3106        let is_img2img = req.source_image.is_some();
3107
3108        // For img2img, encode source image using the pre-loaded VAE
3109        let (prepared_img2img_latents, inpaint_ctx) =
3110            if let Some(ref source_bytes) = req.source_image {
3111                let encoded = Self::encode_vae_with_fallback(
3112                    source_bytes,
3113                    req.width,
3114                    req.height,
3115                    &loaded.vae,
3116                    &loaded.vae_device,
3117                    &loaded.device,
3118                    progress,
3119                    || {
3120                        Ok(QwenImageVae::load(
3121                            &loaded.vae_path,
3122                            &Device::Cpu,
3123                            DType::F32,
3124                            progress,
3125                        )?)
3126                    },
3127                )?;
3128                let encoded = encoded.to_device(&loaded.device)?.to_dtype(loaded.dtype)?;
3129                let start_sigma = QwenImageScheduler::new_img2img(
3130                    req.steps as usize,
3131                    image_seq_len(latent_h, latent_w, loaded.transformer_cfg.patch_size),
3132                    req.strength,
3133                )
3134                .0
3135                .initial_sigma();
3136                let prepared = crate::img2img::prepare_flow_match_img2img(
3137                    &encoded,
3138                    seed,
3139                    &[1, 16, latent_h, latent_w],
3140                    start_sigma,
3141                    req.mask_image.as_deref(),
3142                    latent_h,
3143                    latent_w,
3144                    &loaded.device,
3145                    loaded.dtype,
3146                )?;
3147
3148                (Some(prepared.initial_latents), prepared.inpaint_ctx)
3149            } else {
3150                (None, None)
3151            };
3152
3153        // 4. Initialize scheduler
3154        let image_seq_len = image_seq_len(latent_h, latent_w, loaded.transformer_cfg.patch_size);
3155        let (mut scheduler, num_steps) = if is_img2img {
3156            QwenImageScheduler::new_img2img(req.steps as usize, image_seq_len, req.strength)
3157        } else {
3158            let sched = QwenImageScheduler::new(req.steps as usize, image_seq_len);
3159            let n = sched.num_steps();
3160            (sched, n)
3161        };
3162
3163        // 5. Build initial latents
3164        let mut latents = if let Some(initial) = &prepared_img2img_latents {
3165            initial.clone()
3166        } else {
3167            let noise = crate::engine::seeded_randn(
3168                seed,
3169                &[1, 16, latent_h, latent_w],
3170                &loaded.device,
3171                loaded.dtype,
3172            )?;
3173            (noise * scheduler.initial_sigma())?
3174        };
3175
3176        // 7. Denoising loop
3177        let denoise_label = format!("Denoising ({} steps)", num_steps);
3178        progress.stage_start(&denoise_label);
3179        let denoise_start = Instant::now();
3180
3181        {
3182            let transformer = loaded
3183                .transformer
3184                .as_ref()
3185                .expect("transformer must be loaded for denoising");
3186
3187            let use_batched_cfg = use_cfg && transformer.supports_cfg_batching();
3188            if use_cfg && !use_batched_cfg {
3189                progress.info(
3190                    "Low-memory quantized Qwen CUDA path detected — disabling CFG batching to reduce peak CUDA memory",
3191                );
3192            }
3193
3194            // Pre-batch CFG inputs when the selected transformer path can handle
3195            // the extra batch dimension without exceeding peak memory.
3196            let (batched_hs, batched_mask) = if use_batched_cfg {
3197                let hs = Tensor::cat(&[&encoder_hidden_states, uncond_hs.as_ref().unwrap()], 0)?;
3198                let mask =
3199                    Tensor::cat(&[&encoder_attention_mask, uncond_mask.as_ref().unwrap()], 0)?;
3200                (hs, mask)
3201            } else {
3202                (
3203                    encoder_hidden_states.clone(),
3204                    encoder_attention_mask.clone(),
3205                )
3206            };
3207
3208            for step in 0..num_steps {
3209                let step_start = Instant::now();
3210                let t = scheduler.current_timestep();
3211                let noise_pred = if use_cfg {
3212                    let (cond_pred, uncond_pred) = if use_batched_cfg {
3213                        let t_tensor = Tensor::from_vec(vec![t as f32; 2], (2,), &loaded.device)?
3214                            .to_dtype(loaded.dtype)?;
3215                        let batched_latents = Tensor::cat(&[&latents, &latents], 0)?;
3216                        let batched_pred = transformer.forward(
3217                            &batched_latents,
3218                            &t_tensor,
3219                            &batched_hs,
3220                            &batched_mask,
3221                        )?;
3222                        (batched_pred.narrow(0, 0, 1)?, batched_pred.narrow(0, 1, 1)?)
3223                    } else {
3224                        let t_tensor = Tensor::from_vec(vec![t as f32], (1,), &loaded.device)?
3225                            .to_dtype(loaded.dtype)?;
3226                        (
3227                            transformer.forward(
3228                                &latents,
3229                                &t_tensor,
3230                                &encoder_hidden_states,
3231                                &encoder_attention_mask,
3232                            )?,
3233                            transformer.forward(
3234                                &latents,
3235                                &t_tensor,
3236                                uncond_hs.as_ref().unwrap(),
3237                                uncond_mask.as_ref().unwrap(),
3238                            )?,
3239                        )
3240                    };
3241                    // CFG in F32 + norm rescale (matches diffusers Qwen-Image pipeline)
3242                    let cond_f32 = cond_pred.to_dtype(DType::F32)?;
3243                    let uncond_f32 = uncond_pred.to_dtype(DType::F32)?;
3244                    let comb = (&uncond_f32 + ((&cond_f32 - &uncond_f32)? * req.guidance)?)?;
3245                    let cond_norm = cond_f32.sqr()?.sum_keepdim(1)?.sqrt()?;
3246                    let comb_norm = comb.sqr()?.sum_keepdim(1)?.sqrt()?.clamp(1e-8, f64::MAX)?;
3247                    let rescaled = comb.broadcast_mul(&(cond_norm / comb_norm)?)?;
3248                    rescaled.to_dtype(loaded.dtype)?
3249                } else {
3250                    let t_tensor = Tensor::from_vec(vec![t as f32], (1,), &loaded.device)?
3251                        .to_dtype(loaded.dtype)?;
3252                    transformer.forward(
3253                        &latents,
3254                        &t_tensor,
3255                        &encoder_hidden_states,
3256                        &encoder_attention_mask,
3257                    )?
3258                };
3259                if step == 0 || step == num_steps / 2 || step == num_steps - 1 {
3260                    Self::debug_tensor_stats(&format!("noise_pred[{step}]"), &noise_pred);
3261                    Self::debug_tensor_stats(&format!("latents[{step}]"), &latents);
3262                }
3263                if step == 0 {
3264                    Self::validate_qwen_tensor_boundary("noise_pred[0]", &noise_pred)?;
3265                }
3266                latents = scheduler.step(&noise_pred, &latents)?;
3267                if step == num_steps - 1 {
3268                    Self::validate_qwen_tensor_boundary("latents_final", &latents)?;
3269                }
3270
3271                // Inpainting: blend preserved regions back at current noise level
3272                if let Some(ref ctx) = inpaint_ctx {
3273                    latents = crate::img2img::apply_flow_match_inpaint(
3274                        &latents,
3275                        ctx,
3276                        scheduler.sigmas[step + 1],
3277                    )?;
3278                }
3279
3280                progress.emit(ProgressEvent::DenoiseStep {
3281                    step: step + 1,
3282                    total: num_steps,
3283                    elapsed: step_start.elapsed(),
3284                });
3285            }
3286        }
3287
3288        progress.stage_done(&denoise_label, denoise_start.elapsed());
3289
3290        // Free text embeddings
3291        drop(encoder_hidden_states);
3292        drop(encoder_attention_mask);
3293        drop(uncond_hs);
3294        drop(uncond_mask);
3295
3296        // 8. VAE decode
3297        progress.stage_start("VAE decode");
3298        let vae_start = Instant::now();
3299        let free_for_decode = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
3300        let prefer_tiled = Self::should_proactively_tile_vae_decode(
3301            req.width,
3302            req.height,
3303            loaded.vae_device.is_cuda(),
3304            free_for_decode,
3305        );
3306
3307        // Always decode in F32 — matches sequential path and diffusers' force_upcast.
3308        let keep_transformer_hot = Self::can_keep_transformer_hot_for_vae(loaded);
3309        let image = if keep_transformer_hot {
3310            match Self::decode_vae_gpu_only(
3311                &latents,
3312                &loaded.vae,
3313                &loaded.vae_device,
3314                &loaded.device,
3315                progress,
3316                prefer_tiled,
3317            ) {
3318                Ok(image) => {
3319                    progress.info(
3320                        "Kept quantized Qwen transformer resident across VAE decode for faster hot-path reuse",
3321                    );
3322                    image
3323                }
3324                Err(err) if Self::is_oom_error(&err) => {
3325                    loaded.transformer = None;
3326                    loaded.device.synchronize()?;
3327                    progress.info(
3328                        "Dropping Qwen-Image transformer after resident VAE decode OOM and retrying",
3329                    );
3330                    Self::decode_vae_with_fallback(
3331                        &latents,
3332                        &loaded.vae,
3333                        &loaded.vae_device,
3334                        &loaded.device,
3335                        progress,
3336                        prefer_tiled,
3337                        || {
3338                            QwenImageVae::load(&loaded.vae_path, &Device::Cpu, DType::F32, progress)
3339                                .map_err(Into::into)
3340                        },
3341                    )?
3342                }
3343                Err(err) => return Err(err),
3344            }
3345        } else {
3346            loaded.transformer = None;
3347            loaded.device.synchronize()?;
3348            tracing::info!("Qwen-Image transformer dropped to free VRAM for VAE decode");
3349            Self::decode_vae_with_fallback(
3350                &latents,
3351                &loaded.vae,
3352                &loaded.vae_device,
3353                &loaded.device,
3354                progress,
3355                prefer_tiled,
3356                || {
3357                    QwenImageVae::load(&loaded.vae_path, &Device::Cpu, DType::F32, progress)
3358                        .map_err(Into::into)
3359                },
3360            )?
3361        };
3362        Self::validate_qwen_tensor_boundary("image_pre_postprocess", &image)?;
3363        Self::debug_tensor_stats("image_pre_postprocess", &image);
3364        let image = postprocess_image(&image)?;
3365        let post_stats = Self::validate_qwen_tensor_boundary("image_postprocess", &image)?;
3366        Self::debug_tensor_stats("image_postprocess", &image);
3367        let image = image.i(0)?;
3368        if Self::near_black_image_stats(post_stats) {
3369            progress.info(
3370                "Qwen diagnostic: decoded image is near-black after VAE postprocess; inspect MOLD_QWEN_DEBUG tensor stats to separate denoise math from VAE decode",
3371            );
3372            tracing::warn!(
3373                min = post_stats.min,
3374                max = post_stats.max,
3375                mean = post_stats.mean,
3376                "Qwen decoded image is near-black after VAE postprocess"
3377            );
3378        }
3379
3380        progress.stage_done("VAE decode", vae_start.elapsed());
3381
3382        // 9. Encode to output format
3383        let output_metadata = build_output_metadata(req, seed, None);
3384        let image_bytes = encode_image(
3385            &image,
3386            req.resolved_output_format(),
3387            req.width,
3388            req.height,
3389            output_metadata.as_ref(),
3390        )?;
3391
3392        let generation_time_ms = start.elapsed().as_millis() as u64;
3393        tracing::info!(generation_time_ms, seed, "Qwen-Image generation complete");
3394
3395        Ok(GenerateResponse {
3396            images: vec![ImageData {
3397                data: image_bytes,
3398                format: req.resolved_output_format(),
3399                width: req.width,
3400                height: req.height,
3401                index: 0,
3402            }],
3403            generation_time_ms,
3404            model: req.model.clone(),
3405            seed_used: seed,
3406            video: None,
3407            gpu: None,
3408        })
3409    }
3410}
3411
3412impl InferenceEngine for QwenImageEngine {
3413    fn generate(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
3414        self.pending_placement = req.placement.clone();
3415        self.pending_loras = effective_loras(req);
3416        let result = self.generate_inner(req);
3417        self.pending_placement = None;
3418        self.pending_loras.clear();
3419        result
3420    }
3421
3422    fn model_name(&self) -> &str {
3423        self.base.model_name()
3424    }
3425
3426    fn is_loaded(&self) -> bool {
3427        self.base.is_loaded()
3428    }
3429
3430    fn load(&mut self) -> Result<()> {
3431        QwenImageEngine::load(self)
3432    }
3433
3434    fn unload(&mut self) {
3435        self.base.unload();
3436        clear_cache(&self.prompt_cache);
3437    }
3438
3439    fn set_on_progress(&mut self, callback: ProgressCallback) {
3440        self.base.set_on_progress(callback);
3441    }
3442
3443    fn clear_on_progress(&mut self) {
3444        self.base.clear_on_progress();
3445    }
3446
3447    fn model_paths(&self) -> Option<&mold_core::ModelPaths> {
3448        Some(&self.base.paths)
3449    }
3450}
3451
3452#[cfg(test)]
3453mod tests {
3454    use super::*;
3455    use crate::engine::LoadStrategy;
3456    use crate::shared_pool::SharedPool;
3457    use candle_core::Shape;
3458    use mold_core::ModelPaths;
3459    use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
3460    use std::collections::HashMap;
3461    use std::fs;
3462    use std::path::{Path, PathBuf};
3463    use std::sync::{Arc, Mutex};
3464    use std::time::{SystemTime, UNIX_EPOCH};
3465    use tokenizers::models::bpe::BPE;
3466
3467    fn temp_test_dir(prefix: &str) -> PathBuf {
3468        let suffix = SystemTime::now()
3469            .duration_since(UNIX_EPOCH)
3470            .unwrap()
3471            .as_nanos();
3472        let dir = std::env::temp_dir().join(format!("{prefix}-{}-{suffix}", std::process::id()));
3473        fs::create_dir_all(&dir).unwrap();
3474        dir
3475    }
3476
3477    fn touch(dir: &Path, name: &str) -> PathBuf {
3478        let path = dir.join(name);
3479        fs::write(&path, b"test").unwrap();
3480        path
3481    }
3482
3483    fn png_with_dimensions(width: u32, height: u32) -> Vec<u8> {
3484        let img = image::RgbImage::from_fn(width, height, |_, _| image::Rgb([255, 0, 0]));
3485        let mut buf = std::io::Cursor::new(Vec::new());
3486        image::DynamicImage::ImageRgb8(img)
3487            .write_to(&mut buf, image::ImageFormat::Png)
3488            .unwrap();
3489        buf.into_inner()
3490    }
3491
3492    fn qwen_image_model_paths(
3493        transformer: PathBuf,
3494        transformer_shards: Vec<PathBuf>,
3495        vae: PathBuf,
3496        text_tokenizer: Option<PathBuf>,
3497    ) -> ModelPaths {
3498        ModelPaths {
3499            transformer,
3500            transformer_shards,
3501            vae,
3502            spatial_upscaler: None,
3503            temporal_upscaler: None,
3504            distilled_lora: None,
3505            t5_encoder: None,
3506            clip_encoder: None,
3507            t5_tokenizer: None,
3508            clip_tokenizer: None,
3509            clip_encoder_2: None,
3510            clip_tokenizer_2: None,
3511            text_encoder_files: vec![],
3512            text_tokenizer,
3513            decoder: None,
3514        }
3515    }
3516
3517    fn resolved_text_encoder(is_gguf: bool, auto_use_gpu: bool) -> ResolvedQwen2TextEncoder {
3518        ResolvedQwen2TextEncoder {
3519            paths: vec![],
3520            vision_paths: vec![],
3521            is_gguf,
3522            variant_label: if is_gguf {
3523                "q6".to_string()
3524            } else {
3525                "bf16".to_string()
3526            },
3527            size_bytes: 0,
3528            auto_use_gpu,
3529        }
3530    }
3531
3532    fn tensor_values_u8(t: &Tensor) -> Vec<u8> {
3533        t.flatten_all()
3534            .unwrap()
3535            .to_vec1::<u8>()
3536            .expect("u8 tensor values")
3537    }
3538
3539    fn tensor_values_f32(t: &Tensor) -> Vec<f32> {
3540        t.flatten_all()
3541            .unwrap()
3542            .to_vec1::<f32>()
3543            .expect("f32 tensor values")
3544    }
3545
3546    #[test]
3547    fn safetensors_is_fp8_uses_filename_hint() {
3548        assert!(safetensors_is_fp8(Path::new(
3549            "/tmp/qwen-image-fp8.safetensors"
3550        )));
3551        assert!(!safetensors_is_fp8(Path::new(
3552            "/tmp/qwen-image.safetensors"
3553        )));
3554    }
3555
3556    #[test]
3557    fn text_encoder_is_fp8_uses_filename_hint() {
3558        assert!(text_encoder_is_fp8(&[PathBuf::from(
3559            "/tmp/qwen2-text-encoder-fp8-00001-of-00002.safetensors"
3560        )]));
3561        assert!(!text_encoder_is_fp8(&[PathBuf::from(
3562            "/tmp/qwen2-text-encoder-00001-of-00002.safetensors"
3563        )]));
3564    }
3565
3566    #[test]
3567    fn cached_prompt_conditioning_roundtrips_and_restores_mask() {
3568        let device = Device::Cpu;
3569        let hidden_states = Tensor::from_vec(
3570            vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
3571            Shape::from((1, 3, 2)),
3572            &device,
3573        )
3574        .unwrap();
3575        let cached = CachedPromptConditioning::from_parts(&hidden_states, 2).unwrap();
3576
3577        let (restored_hs, restored_mask) = cached.restore(&device, DType::F32).unwrap();
3578
3579        assert_eq!(
3580            tensor_values_f32(&restored_hs),
3581            tensor_values_f32(&hidden_states)
3582        );
3583        assert_eq!(tensor_values_u8(&restored_mask), vec![1, 1, 0]);
3584    }
3585
3586    #[test]
3587    fn pad_text_conditioning_keeps_original_when_target_matches() {
3588        let device = Device::Cpu;
3589        let hidden_states =
3590            Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], Shape::from((1, 2, 2)), &device).unwrap();
3591        let mask = Tensor::from_vec(vec![1u8, 1], Shape::from((1, 2)), &device).unwrap();
3592
3593        let (padded_hs, padded_mask) = pad_text_conditioning(&hidden_states, &mask, 2).unwrap();
3594
3595        assert_eq!(
3596            tensor_values_f32(&padded_hs),
3597            tensor_values_f32(&hidden_states)
3598        );
3599        assert_eq!(tensor_values_u8(&padded_mask), vec![1, 1]);
3600    }
3601
3602    #[test]
3603    fn pad_text_conditioning_appends_zero_padding() {
3604        let device = Device::Cpu;
3605        let hidden_states =
3606            Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], Shape::from((1, 2, 2)), &device).unwrap();
3607        let mask = Tensor::from_vec(vec![1u8, 0], Shape::from((1, 2)), &device).unwrap();
3608
3609        let (padded_hs, padded_mask) = pad_text_conditioning(&hidden_states, &mask, 4).unwrap();
3610
3611        assert_eq!(padded_hs.dims3().unwrap(), (1, 4, 2));
3612        assert_eq!(
3613            tensor_values_f32(&padded_hs),
3614            vec![1.0, 2.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0]
3615        );
3616        assert_eq!(tensor_values_u8(&padded_mask), vec![1, 0, 0, 0]);
3617    }
3618
3619    #[test]
3620    fn pad_text_conditioning_rejects_shrinking() {
3621        let device = Device::Cpu;
3622        let hidden_states =
3623            Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], Shape::from((1, 2, 2)), &device).unwrap();
3624        let mask = Tensor::from_vec(vec![1u8, 1], Shape::from((1, 2)), &device).unwrap();
3625
3626        let err = pad_text_conditioning(&hidden_states, &mask, 1).unwrap_err();
3627        assert!(err.to_string().contains("cannot shrink text conditioning"));
3628    }
3629
3630    #[test]
3631    fn align_cfg_conditioning_pads_shorter_branch_to_match_longer_one() {
3632        let device = Device::Cpu;
3633        let cond_hs = Tensor::from_vec(
3634            vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
3635            Shape::from((1, 3, 2)),
3636            &device,
3637        )
3638        .unwrap();
3639        let cond_mask = Tensor::from_vec(vec![1u8, 1, 1], Shape::from((1, 3)), &device).unwrap();
3640        let uncond_hs = Tensor::from_vec(
3641            vec![7.0f32, 8.0, 9.0, 10.0],
3642            Shape::from((1, 2, 2)),
3643            &device,
3644        )
3645        .unwrap();
3646        let uncond_mask = Tensor::from_vec(vec![1u8, 0], Shape::from((1, 2)), &device).unwrap();
3647
3648        let ((cond_hs, cond_mask), (uncond_hs, uncond_mask)) =
3649            align_cfg_conditioning(&cond_hs, &cond_mask, &uncond_hs, &uncond_mask).unwrap();
3650
3651        assert_eq!(cond_hs.dims3().unwrap(), (1, 3, 2));
3652        assert_eq!(uncond_hs.dims3().unwrap(), (1, 3, 2));
3653        assert_eq!(tensor_values_u8(&cond_mask), vec![1, 1, 1]);
3654        assert_eq!(tensor_values_u8(&uncond_mask), vec![1, 0, 0]);
3655        assert_eq!(
3656            tensor_values_f32(&uncond_hs),
3657            vec![7.0, 8.0, 9.0, 10.0, 0.0, 0.0]
3658        );
3659    }
3660
3661    #[test]
3662    fn qwen_image_detects_gguf_transformer() {
3663        let engine = QwenImageEngine::new(
3664            "qwen-image:q4".to_string(),
3665            ModelPaths {
3666                transformer: PathBuf::from("/tmp/qwen-image-Q4_K_S.gguf"),
3667                transformer_shards: vec![],
3668                vae: PathBuf::from("/tmp/vae.safetensors"),
3669                spatial_upscaler: None,
3670                temporal_upscaler: None,
3671                distilled_lora: None,
3672                t5_encoder: None,
3673                clip_encoder: None,
3674                t5_tokenizer: None,
3675                clip_tokenizer: None,
3676                clip_encoder_2: None,
3677                clip_tokenizer_2: None,
3678                text_encoder_files: vec![],
3679                text_tokenizer: Some(PathBuf::from("/tmp/tokenizer.json")),
3680                decoder: None,
3681            },
3682            LoadStrategy::Sequential,
3683            0,
3684            false,
3685            None,
3686        );
3687
3688        assert!(engine.detect_is_quantized());
3689    }
3690
3691    #[test]
3692    fn qwen_image_text_encoder_uses_gpu_on_metal() {
3693        let plan = QwenImageEngine::qwen2_text_encoder_plan_for_mode(
3694            Qwen2TextEncoderMode::Auto,
3695            false,
3696            true,
3697            &resolved_text_encoder(true, true),
3698        );
3699        assert!(plan.use_gpu);
3700        assert!(!plan.use_cpu_staging);
3701    }
3702
3703    #[test]
3704    fn qwen_image_text_encoder_uses_gpu_on_cuda_with_headroom() {
3705        let plan = QwenImageEngine::qwen2_text_encoder_plan_for_mode(
3706            Qwen2TextEncoderMode::Auto,
3707            true,
3708            false,
3709            &resolved_text_encoder(false, true),
3710        );
3711        assert!(plan.use_gpu);
3712        assert!(!plan.use_cpu_staging);
3713    }
3714
3715    #[test]
3716    fn qwen_image_text_encoder_uses_cpu_on_cuda_without_headroom() {
3717        let plan = QwenImageEngine::qwen2_text_encoder_plan_for_mode(
3718            Qwen2TextEncoderMode::Auto,
3719            true,
3720            false,
3721            &resolved_text_encoder(false, false),
3722        );
3723        assert!(!plan.use_gpu);
3724        assert!(!plan.use_cpu_staging);
3725    }
3726
3727    #[test]
3728    fn qwen_image_cpu_safetensors_text_encoder_stays_f32() {
3729        assert_eq!(
3730            QwenImageEngine::text_encoder_load_dtype(false, DType::BF16),
3731            DType::F32
3732        );
3733    }
3734
    /// Per its name, checks that a CPU-placed GGUF text encoder loads in F32
    /// when the model dtype is BF16.
    ///
    /// NOTE(review): this call is identical to the one in
    /// `qwen_image_cpu_safetensors_text_encoder_stays_f32`. No GGUF/safetensors
    /// distinction is passed to `text_encoder_load_dtype` here, so nothing
    /// GGUF-specific is exercised. Confirm whether a distinct GGUF case was
    /// intended.
    #[test]
    fn qwen_image_cpu_gguf_text_encoder_stays_f32() {
        assert_eq!(
            QwenImageEngine::text_encoder_load_dtype(false, DType::BF16),
            DType::F32
        );
    }
3742
3743    #[test]
3744    fn qwen_image_text_encoder_gpu_override_disables_metal_staging() {
3745        let plan = QwenImageEngine::qwen2_text_encoder_plan_for_mode(
3746            Qwen2TextEncoderMode::Gpu,
3747            false,
3748            true,
3749            &resolved_text_encoder(true, true),
3750        );
3751        assert!(plan.use_gpu);
3752        assert!(!plan.use_cpu_staging);
3753    }
3754
3755    #[test]
3756    fn qwen_image_auto_prefers_q6_on_metal_with_headroom() {
3757        let q6 = mold_core::manifest::find_qwen2_vl_variant("q6").unwrap();
3758        let resolved = QwenImageEngine::choose_text_encoder_source(
3759            Some("auto"),
3760            false,
3761            true,
3762            qwen2_vram_threshold(q6.size_bytes) + 1,
3763            16_600_000_000,
3764            Qwen2TextEncoderUsage::Resident,
3765        )
3766        .unwrap();
3767        assert!(resolved.is_gguf);
3768        assert_eq!(resolved.variant_label, "q6");
3769        assert!(resolved.auto_use_gpu);
3770    }
3771
3772    #[test]
3773    fn qwen_image_auto_falls_back_to_q4_on_metal_when_q6_does_not_fit() {
3774        let q4 = mold_core::manifest::find_qwen2_vl_variant("q4").unwrap();
3775        let q6 = mold_core::manifest::find_qwen2_vl_variant("q6").unwrap();
3776        let free_vram = qwen2_vram_threshold(q4.size_bytes);
3777        assert!(free_vram < qwen2_vram_threshold(q6.size_bytes));
3778
3779        let resolved = QwenImageEngine::choose_text_encoder_source(
3780            Some("auto"),
3781            false,
3782            true,
3783            free_vram,
3784            0,
3785            Qwen2TextEncoderUsage::Resident,
3786        )
3787        .unwrap();
3788        assert!(resolved.is_gguf);
3789        assert_eq!(resolved.variant_label, "q4");
3790        assert!(resolved.auto_use_gpu);
3791    }
3792
3793    #[test]
3794    fn qwen_image_auto_keeps_bf16_default_on_cuda() {
3795        let resolved = QwenImageEngine::choose_text_encoder_source(
3796            Some("auto"),
3797            true,
3798            false,
3799            QWEN2_FP16_VRAM_THRESHOLD + 1,
3800            16_600_000_000,
3801            Qwen2TextEncoderUsage::Resident,
3802        )
3803        .unwrap();
3804        assert!(!resolved.is_gguf);
3805        assert_eq!(resolved.variant_label, "bf16");
3806        assert!(resolved.auto_use_gpu);
3807    }
3808
3809    #[test]
3810    fn qwen_image_auto_prefers_quantized_gpu_on_cuda_for_resident_mode_when_it_fits() {
3811        let resolved = QwenImageEngine::choose_text_encoder_source(
3812            Some("auto"),
3813            true,
3814            false,
3815            QWEN2_FP16_VRAM_THRESHOLD - 1,
3816            16_600_000_000,
3817            Qwen2TextEncoderUsage::Resident,
3818        )
3819        .unwrap();
3820        assert!(resolved.is_gguf);
3821        assert_eq!(resolved.variant_label, "q4");
3822        assert!(resolved.auto_use_gpu);
3823    }
3824
3825    #[test]
3826    fn qwen_image_auto_uses_quantized_cpu_fallback_on_cuda_for_resident_mode() {
3827        let resolved = QwenImageEngine::choose_text_encoder_source(
3828            Some("auto"),
3829            true,
3830            false,
3831            1,
3832            16_600_000_000,
3833            Qwen2TextEncoderUsage::Resident,
3834        )
3835        .unwrap();
3836        assert!(resolved.is_gguf);
3837        assert_eq!(resolved.variant_label, "q4");
3838        assert!(!resolved.auto_use_gpu);
3839    }
3840
3841    #[test]
3842    fn qwen_image_auto_prefers_quantized_gpu_on_cuda_for_sequential_mode_when_it_fits() {
3843        let resolved = QwenImageEngine::choose_text_encoder_source(
3844            Some("auto"),
3845            true,
3846            false,
3847            QWEN2_FP16_VRAM_THRESHOLD - 1,
3848            16_600_000_000,
3849            Qwen2TextEncoderUsage::Sequential,
3850        )
3851        .unwrap();
3852        assert!(resolved.is_gguf);
3853        assert_eq!(resolved.variant_label, "q4");
3854        assert!(resolved.auto_use_gpu);
3855    }
3856
3857    #[test]
3858    fn qwen_image_auto_uses_quantized_cpu_fallback_on_cuda_for_sequential_mode() {
3859        let resolved = QwenImageEngine::choose_text_encoder_source(
3860            Some("auto"),
3861            true,
3862            false,
3863            1,
3864            16_600_000_000,
3865            Qwen2TextEncoderUsage::Sequential,
3866        )
3867        .unwrap();
3868        assert!(resolved.is_gguf);
3869        assert_eq!(resolved.variant_label, "q4");
3870        assert!(!resolved.auto_use_gpu);
3871    }
3872
3873    #[test]
3874    fn qwen_image_explicit_q6_respects_cpu_fallback_on_cuda() {
3875        let resolved = QwenImageEngine::choose_text_encoder_source(
3876            Some("q6"),
3877            true,
3878            false,
3879            1,
3880            0,
3881            Qwen2TextEncoderUsage::Resident,
3882        )
3883        .unwrap();
3884        assert!(resolved.is_gguf);
3885        assert_eq!(resolved.variant_label, "q6");
3886        assert!(!resolved.auto_use_gpu);
3887    }
3888
3889    #[test]
3890    fn qwen_image_edit_accepts_quantized_text_with_bf16_vision_sidecar() {
3891        let dir = temp_test_dir("qwen-image-edit-text-encoder");
3892        let transformer = touch(&dir, "qwen-image-edit.gguf");
3893        let vae = touch(&dir, "vae.safetensors");
3894        let tokenizer = touch(&dir, "tokenizer.json");
3895        let mut paths = qwen_image_model_paths(transformer, vec![], vae, Some(tokenizer));
3896        paths.text_encoder_files = vec![touch(&dir, "text-encoder-00001-of-00004.safetensors")];
3897        let engine = QwenImageEngine::new(
3898            "qwen-image-edit-2511:q4".to_string(),
3899            paths,
3900            LoadStrategy::Sequential,
3901            0,
3902            false,
3903            None,
3904        );
3905
3906        let resolved = engine
3907            .resolve_text_encoder_source_with_preference(
3908                &Device::Cpu,
3909                0,
3910                Qwen2TextEncoderUsage::Sequential,
3911                Some("auto"),
3912            )
3913            .unwrap();
3914        assert!(!resolved.vision_paths.is_empty());
3915
3916        let resolved = engine
3917            .resolve_text_encoder_source_with_preference(
3918                &Device::Cpu,
3919                0,
3920                Qwen2TextEncoderUsage::Sequential,
3921                Some("q4"),
3922            )
3923            .unwrap();
3924        assert!(resolved.is_gguf);
3925        assert_eq!(resolved.variant_label, "q4");
3926        assert_eq!(resolved.vision_paths.len(), 1);
3927
3928        let resolved = engine
3929            .resolve_text_encoder_source_with_preference(
3930                &Device::Cpu,
3931                0,
3932                Qwen2TextEncoderUsage::Sequential,
3933                Some("bf16"),
3934            )
3935            .unwrap();
3936        assert!(!resolved.is_gguf);
3937        assert_eq!(resolved.variant_label, "bf16");
3938        assert_eq!(resolved.vision_paths.len(), 1);
3939    }
3940
3941    #[test]
3942    fn qwen_image_edit_prompt_numbers_each_picture_placeholder() {
3943        let prompt = QwenImageEngine::qwen_image_edit_prompt("swap materials", 3);
3944        assert!(prompt.contains(QWEN_IMAGE_EDIT_SYSTEM_PROMPT));
3945        assert!(prompt.contains("Picture 1: <|vision_start|><|image_pad|><|vision_end|>"));
3946        assert!(prompt.contains("Picture 2: <|vision_start|><|image_pad|><|vision_end|>"));
3947        assert!(prompt.contains("Picture 3: <|vision_start|><|image_pad|><|vision_end|>"));
3948        assert!(prompt.ends_with("<|im_start|>assistant\n"));
3949    }
3950
3951    #[test]
3952    fn qwen_image_edit_image_dims_fit_target_area_with_16px_alignment() {
3953        let bytes = png_with_dimensions(1600, 900);
3954        let (width, height) =
3955            QwenImageEngine::qwen_image_edit_image_dims(&bytes, QWEN_IMAGE_EDIT_VAE_AREA).unwrap();
3956        assert_eq!((width, height), (1360, 768));
3957        assert_eq!(width % 16, 0);
3958        assert_eq!(height % 16, 0);
3959    }
3960
3961    #[test]
3962    fn pack_and_unpack_latents_roundtrip() {
3963        let values: Vec<f32> = (0..(16 * 4 * 6)).map(|i| i as f32).collect();
3964        let latents = Tensor::from_vec(values.clone(), (1, 16, 4, 6), &Device::Cpu).unwrap();
3965        let packed = QwenImageEngine::pack_latents_4d(&latents).unwrap();
3966        assert_eq!(packed.dims3().unwrap(), (1, 6, 64));
3967
3968        let unpacked = QwenImageEngine::unpack_latents_packed(&packed, 4, 6).unwrap();
3969        assert_eq!(
3970            unpacked.flatten_all().unwrap().to_vec1::<f32>().unwrap(),
3971            values
3972        );
3973    }
3974
3975    #[test]
3976    fn quantized_cuda_cfg_headroom_scales_with_resolution() {
3977        let native = QwenImageEngine::quantized_cuda_cfg_headroom(1328, 1328);
3978        let reduced = QwenImageEngine::quantized_cuda_cfg_headroom(512, 512);
3979        assert_eq!(native, QWEN_GGUF_NATIVE_CFG_HEADROOM);
3980        assert_eq!(reduced, QWEN_GGUF_MIN_CFG_HEADROOM);
3981    }
3982
3983    #[test]
3984    fn qwen_quantized_native_resolution_uses_split_cfg_on_24gb_cuda() {
3985        assert!(QwenImageEngine::should_split_cfg_quantized_cuda(
3986            12_300_000_000,
3987            24_600_000_000,
3988            1328,
3989            1328,
3990        ));
3991    }
3992
3993    #[test]
3994    fn qwen_quantized_reduced_resolution_keeps_batched_cfg_when_it_fits() {
3995        assert!(!QwenImageEngine::should_split_cfg_quantized_cuda(
3996            12_300_000_000,
3997            24_600_000_000,
3998            512,
3999            512,
4000        ));
4001    }
4002
4003    #[test]
4004    fn qwen_quantized_cfg_split_boundary_does_not_split_when_estimate_exactly_fits() {
4005        let headroom = QwenImageEngine::quantized_cuda_cfg_headroom(1328, 1328);
4006        let transformer_size = 12_300_000_000;
4007        let free_vram = transformer_size + headroom;
4008        assert!(!QwenImageEngine::should_split_cfg_quantized_cuda(
4009            transformer_size,
4010            free_vram,
4011            1328,
4012            1328,
4013        ));
4014    }
4015
4016    #[test]
4017    fn qwen_quantized_unknown_vram_biases_to_split_cfg() {
4018        assert!(QwenImageEngine::should_split_cfg_quantized_cuda(
4019            12_300_000_000,
4020            0,
4021            1328,
4022            1328,
4023        ));
4024    }
4025
4026    #[test]
4027    fn qwen_is_oom_error_matches_cuda_memory_allocation_string() {
4028        assert!(QwenImageEngine::is_oom_error(&"cudaErrorMemoryAllocation"));
4029    }
4030
4031    #[test]
4032    fn qwen_debug_stats_counts_nan_and_inf() {
4033        let tensor = Tensor::from_vec(
4034            vec![0.0f32, 1.0, f32::NAN, f32::INFINITY, f32::NEG_INFINITY],
4035            Shape::from((5,)),
4036            &Device::Cpu,
4037        )
4038        .unwrap();
4039
4040        let stats = QwenImageEngine::tensor_stats(&tensor).unwrap();
4041
4042        assert_eq!(stats.total, 5);
4043        assert_eq!(stats.nan_count, 1);
4044        assert_eq!(stats.pos_inf_count, 1);
4045        assert_eq!(stats.neg_inf_count, 1);
4046        assert_eq!(stats.min, 0.0);
4047        assert_eq!(stats.max, 1.0);
4048        assert_eq!(stats.mean, 0.5);
4049    }
4050
4051    #[test]
4052    fn qwen_debug_stats_detects_near_black_postprocessed_image() {
4053        let stats = QwenTensorStats {
4054            min: 0.0,
4055            max: 0.01,
4056            mean: 0.004,
4057            nan_count: 0,
4058            pos_inf_count: 0,
4059            neg_inf_count: 0,
4060            total: 1024,
4061        };
4062
4063        assert!(QwenImageEngine::near_black_image_stats(stats));
4064    }
4065
4066    #[test]
4067    fn qwen_debug_stats_does_not_flag_non_black_image() {
4068        let stats = QwenTensorStats {
4069            min: 0.0,
4070            max: 0.75,
4071            mean: 0.18,
4072            nan_count: 0,
4073            pos_inf_count: 0,
4074            neg_inf_count: 0,
4075            total: 1024,
4076        };
4077
4078        assert!(!QwenImageEngine::near_black_image_stats(stats));
4079    }
4080
4081    #[test]
4082    fn qwen_debug_stats_formats_progress_message() {
4083        let stats = QwenTensorStats {
4084            min: 0.0,
4085            max: 1.0,
4086            mean: 0.5,
4087            nan_count: 2,
4088            pos_inf_count: 1,
4089            neg_inf_count: 1,
4090            total: 10,
4091        };
4092
4093        let message = QwenImageEngine::format_tensor_stats("sample", stats);
4094
4095        assert!(message.contains("NaN=2/10"));
4096        assert!(message.contains("+Inf=1"));
4097        assert!(message.contains("-Inf=1"));
4098    }
4099
4100    #[test]
4101    fn qwen_oom_fallback_returns_primary_success_without_running_fallback() {
4102        let mut progress = ProgressReporter::default();
4103        let messages = std::sync::Arc::new(std::sync::Mutex::new(Vec::<String>::new()));
4104        let messages_clone = messages.clone();
4105        progress.set_callback(Box::new(move |event| {
4106            if let ProgressEvent::Info { message } = event {
4107                messages_clone.lock().unwrap().push(message);
4108            }
4109        }));
4110
4111        let fallback_called = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
4112        let fallback_called_clone = fallback_called.clone();
4113        let value = QwenImageEngine::with_cuda_oom_cpu_fallback(
4114            || Ok(7usize),
4115            || {
4116                fallback_called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
4117                Ok(9usize)
4118            },
4119            true,
4120            &Device::Cpu,
4121            &progress,
4122            "retrying",
4123            |_| true,
4124        )
4125        .unwrap();
4126
4127        assert_eq!(value, 7);
4128        assert!(!fallback_called.load(std::sync::atomic::Ordering::SeqCst));
4129        assert!(messages.lock().unwrap().is_empty());
4130    }
4131
4132    #[test]
4133    fn qwen_oom_fallback_retries_when_primary_ooms_on_cuda() {
4134        let mut progress = ProgressReporter::default();
4135        let messages = std::sync::Arc::new(std::sync::Mutex::new(Vec::<String>::new()));
4136        let messages_clone = messages.clone();
4137        progress.set_callback(Box::new(move |event| {
4138            if let ProgressEvent::Info { message } = event {
4139                messages_clone.lock().unwrap().push(message);
4140            }
4141        }));
4142
4143        let value = QwenImageEngine::with_cuda_oom_cpu_fallback(
4144            || Err(anyhow::anyhow!("cudaErrorMemoryAllocation")),
4145            || Ok(11usize),
4146            true,
4147            &Device::Cpu,
4148            &progress,
4149            "retrying",
4150            QwenImageEngine::is_oom_error,
4151        )
4152        .unwrap();
4153
4154        assert_eq!(value, 11);
4155        assert_eq!(messages.lock().unwrap().as_slice(), ["retrying"]);
4156    }
4157
4158    #[test]
4159    fn qwen_oom_fallback_does_not_retry_non_oom_errors() {
4160        let progress = ProgressReporter::default();
4161        let err = QwenImageEngine::with_cuda_oom_cpu_fallback(
4162            || Err(anyhow::anyhow!("not an oom")),
4163            || Ok(11usize),
4164            true,
4165            &Device::Cpu,
4166            &progress,
4167            "retrying",
4168            QwenImageEngine::is_oom_error,
4169        )
4170        .unwrap_err();
4171
4172        assert!(err.to_string().contains("not an oom"));
4173    }
4174
4175    #[test]
4176    fn qwen_tiled_fallback_returns_primary_success_without_retrying() {
4177        let progress = ProgressReporter::default();
4178        let tiled_called = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
4179        let cpu_called = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
4180        let tiled_called_clone = tiled_called.clone();
4181        let cpu_called_clone = cpu_called.clone();
4182
4183        let value = QwenImageEngine::with_cuda_tiled_then_cpu_fallback(
4184            || Ok(5usize),
4185            || {
4186                tiled_called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
4187                Ok(7usize)
4188            },
4189            || {
4190                cpu_called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
4191                Ok(9usize)
4192            },
4193            true,
4194            false,
4195            &Device::Cpu,
4196            &progress,
4197            "tiled",
4198            "cpu",
4199            |_| true,
4200        )
4201        .unwrap();
4202
4203        assert_eq!(value, 5);
4204        assert!(!tiled_called.load(std::sync::atomic::Ordering::SeqCst));
4205        assert!(!cpu_called.load(std::sync::atomic::Ordering::SeqCst));
4206    }
4207
4208    #[test]
4209    fn qwen_tiled_fallback_uses_tiled_result_before_cpu() {
4210        let mut progress = ProgressReporter::default();
4211        let messages = std::sync::Arc::new(std::sync::Mutex::new(Vec::<String>::new()));
4212        let messages_clone = messages.clone();
4213        progress.set_callback(Box::new(move |event| {
4214            if let ProgressEvent::Info { message } = event {
4215                messages_clone.lock().unwrap().push(message);
4216            }
4217        }));
4218
4219        let cpu_called = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
4220        let cpu_called_clone = cpu_called.clone();
4221        let value = QwenImageEngine::with_cuda_tiled_then_cpu_fallback(
4222            || Err(anyhow::anyhow!("out of memory")),
4223            || Ok(13usize),
4224            || {
4225                cpu_called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
4226                Ok(17usize)
4227            },
4228            true,
4229            false,
4230            &Device::Cpu,
4231            &progress,
4232            "tiled",
4233            "cpu",
4234            QwenImageEngine::is_oom_error,
4235        )
4236        .unwrap();
4237
4238        assert_eq!(value, 13);
4239        assert!(!cpu_called.load(std::sync::atomic::Ordering::SeqCst));
4240        assert_eq!(messages.lock().unwrap().as_slice(), ["tiled"]);
4241    }
4242
4243    #[test]
4244    fn qwen_tiled_fallback_uses_cpu_after_tiled_oom() {
4245        let mut progress = ProgressReporter::default();
4246        let messages = std::sync::Arc::new(std::sync::Mutex::new(Vec::<String>::new()));
4247        let messages_clone = messages.clone();
4248        progress.set_callback(Box::new(move |event| {
4249            if let ProgressEvent::Info { message } = event {
4250                messages_clone.lock().unwrap().push(message);
4251            }
4252        }));
4253
4254        let value = QwenImageEngine::with_cuda_tiled_then_cpu_fallback(
4255            || Err(anyhow::anyhow!("OUT_OF_MEMORY")),
4256            || Err(anyhow::anyhow!("OUT_OF_MEMORY")),
4257            || Ok(19usize),
4258            true,
4259            false,
4260            &Device::Cpu,
4261            &progress,
4262            "tiled",
4263            "cpu",
4264            QwenImageEngine::is_oom_error,
4265        )
4266        .unwrap();
4267
4268        assert_eq!(value, 19);
4269        assert_eq!(messages.lock().unwrap().as_slice(), ["tiled", "cpu"]);
4270    }
4271
4272    #[test]
4273    fn qwen_tiled_fallback_propagates_non_oom_tiled_error() {
4274        let progress = ProgressReporter::default();
4275        let err = QwenImageEngine::with_cuda_tiled_then_cpu_fallback(
4276            || Err(anyhow::anyhow!("out of memory")),
4277            || Err(anyhow::anyhow!("bad tiled decode")),
4278            || Ok(19usize),
4279            true,
4280            false,
4281            &Device::Cpu,
4282            &progress,
4283            "tiled",
4284            "cpu",
4285            QwenImageEngine::is_oom_error,
4286        )
4287        .unwrap_err();
4288
4289        assert!(err.to_string().contains("bad tiled decode"));
4290    }
4291
4292    #[test]
4293    fn qwen_proactive_tiled_policy_selects_native_cuda_under_pressure() {
4294        assert!(QwenImageEngine::should_proactively_tile_vae_decode(
4295            1328,
4296            1328,
4297            true,
4298            6_000_000_000
4299        ));
4300        assert!(!QwenImageEngine::should_proactively_tile_vae_decode(
4301            512,
4302            512,
4303            true,
4304            6_000_000_000
4305        ));
4306        assert!(!QwenImageEngine::should_proactively_tile_vae_decode(
4307            1328,
4308            1328,
4309            false,
4310            6_000_000_000
4311        ));
4312        assert!(!QwenImageEngine::should_proactively_tile_vae_decode(
4313            1328,
4314            1328,
4315            true,
4316            16_000_000_000
4317        ));
4318    }
4319
4320    #[test]
4321    fn qwen_proactive_tiled_decode_skips_primary_full_decode() {
4322        let mut progress = ProgressReporter::default();
4323        let messages = std::sync::Arc::new(std::sync::Mutex::new(Vec::<String>::new()));
4324        let messages_clone = messages.clone();
4325        progress.set_callback(Box::new(move |event| {
4326            if let ProgressEvent::Info { message } = event {
4327                messages_clone.lock().unwrap().push(message);
4328            }
4329        }));
4330
4331        let primary_called = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
4332        let primary_called_clone = primary_called.clone();
4333        let value = QwenImageEngine::with_cuda_tiled_then_cpu_fallback(
4334            || {
4335                primary_called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
4336                Ok(3usize)
4337            },
4338            || Ok(7usize),
4339            || Ok(9usize),
4340            true,
4341            true,
4342            &Device::Cpu,
4343            &progress,
4344            "tiled after oom",
4345            "cpu",
4346            QwenImageEngine::is_oom_error,
4347        )
4348        .unwrap();
4349
4350        assert_eq!(value, 7);
4351        assert!(!primary_called.load(std::sync::atomic::Ordering::SeqCst));
4352        assert_eq!(
4353            messages.lock().unwrap().as_slice(),
4354            ["Selecting tiled GPU VAE decode proactively"]
4355        );
4356    }
4357
4358    #[test]
4359    fn qwen_hot_text_encoder_keeps_gpu_after_cache_miss_with_headroom() {
4360        let action = QwenImageEngine::qwen2_text_encoder_post_encode_action(
4361            Qwen2TextEncoderResidencyInput {
4362                on_gpu: true,
4363                is_quantized: true,
4364                is_metal: false,
4365                keep_te_ram: false,
4366                prompt_cache_miss: true,
4367                transformer_resident: true,
4368                free_vram_bytes: 10_000_000_000,
4369                required_vram_bytes: 8_000_000_000,
4370            },
4371        );
4372
4373        assert_eq!(action, Qwen2TextEncoderPostEncodeAction::KeepGpu);
4374    }
4375
4376    #[test]
4377    fn qwen_hot_text_encoder_drops_after_cache_hit_even_with_headroom() {
4378        let action = QwenImageEngine::qwen2_text_encoder_post_encode_action(
4379            Qwen2TextEncoderResidencyInput {
4380                on_gpu: true,
4381                is_quantized: true,
4382                is_metal: false,
4383                keep_te_ram: false,
4384                prompt_cache_miss: false,
4385                transformer_resident: true,
4386                free_vram_bytes: 10_000_000_000,
4387                required_vram_bytes: 8_000_000_000,
4388            },
4389        );
4390
4391        assert_eq!(action, Qwen2TextEncoderPostEncodeAction::Drop);
4392    }
4393
4394    #[test]
4395    fn qwen_hot_text_encoder_drops_under_transformer_pressure() {
4396        let action = QwenImageEngine::qwen2_text_encoder_post_encode_action(
4397            Qwen2TextEncoderResidencyInput {
4398                on_gpu: true,
4399                is_quantized: true,
4400                is_metal: false,
4401                keep_te_ram: false,
4402                prompt_cache_miss: true,
4403                transformer_resident: true,
4404                free_vram_bytes: 7_999_999_999,
4405                required_vram_bytes: 8_000_000_000,
4406            },
4407        );
4408
4409        assert_eq!(action, Qwen2TextEncoderPostEncodeAction::Drop);
4410    }
4411
4412    #[test]
4413    fn qwen_hot_text_encoder_parks_bf16_when_keep_ram_enabled() {
4414        let action = QwenImageEngine::qwen2_text_encoder_post_encode_action(
4415            Qwen2TextEncoderResidencyInput {
4416                on_gpu: true,
4417                is_quantized: false,
4418                is_metal: false,
4419                keep_te_ram: true,
4420                prompt_cache_miss: true,
4421                transformer_resident: true,
4422                free_vram_bytes: 7_999_999_999,
4423                required_vram_bytes: 8_000_000_000,
4424            },
4425        );
4426
4427        assert_eq!(action, Qwen2TextEncoderPostEncodeAction::ParkCpu);
4428    }
4429
4430    #[test]
4431    fn qwen_hot_text_encoder_never_parks_quantized() {
4432        let action = QwenImageEngine::qwen2_text_encoder_post_encode_action(
4433            Qwen2TextEncoderResidencyInput {
4434                on_gpu: true,
4435                is_quantized: true,
4436                is_metal: false,
4437                keep_te_ram: true,
4438                prompt_cache_miss: true,
4439                transformer_resident: true,
4440                free_vram_bytes: 7_999_999_999,
4441                required_vram_bytes: 8_000_000_000,
4442            },
4443        );
4444
4445        assert_eq!(action, Qwen2TextEncoderPostEncodeAction::Drop);
4446    }
4447
4448    #[test]
4449    fn qwen_hot_text_encoder_drops_when_transformer_not_resident() {
4450        let action = QwenImageEngine::qwen2_text_encoder_post_encode_action(
4451            Qwen2TextEncoderResidencyInput {
4452                on_gpu: true,
4453                is_quantized: true,
4454                is_metal: false,
4455                keep_te_ram: false,
4456                prompt_cache_miss: true,
4457                transformer_resident: false,
4458                free_vram_bytes: 10_000_000_000,
4459                required_vram_bytes: 8_000_000_000,
4460            },
4461        );
4462
4463        assert_eq!(action, Qwen2TextEncoderPostEncodeAction::Drop);
4464    }
4465
4466    #[test]
4467    fn qwen_transformer_hot_vae_eligibility_requires_quantized_cuda_components() {
4468        assert!(QwenImageEngine::qwen_transformer_can_stay_hot_for_vae(
4469            true, true, true
4470        ));
4471        assert!(!QwenImageEngine::qwen_transformer_can_stay_hot_for_vae(
4472            false, true, true
4473        ));
4474        assert!(!QwenImageEngine::qwen_transformer_can_stay_hot_for_vae(
4475            true, false, true
4476        ));
4477        assert!(!QwenImageEngine::qwen_transformer_can_stay_hot_for_vae(
4478            true, true, false
4479        ));
4480    }
4481
4482    #[test]
4483    fn qwen_transformer_paths_prefer_shards_when_present() {
4484        let dir = temp_test_dir("mold-qwen-shards");
4485        let shard_a = touch(&dir, "transformer-00001-of-00002.safetensors");
4486        let shard_b = touch(&dir, "transformer-00002-of-00002.safetensors");
4487        let engine = QwenImageEngine::new(
4488            "qwen-image:q4".to_string(),
4489            qwen_image_model_paths(
4490                dir.join("transformer.safetensors"),
4491                vec![shard_a.clone(), shard_b.clone()],
4492                dir.join("vae.safetensors"),
4493                Some(dir.join("tokenizer.json")),
4494            ),
4495            LoadStrategy::Sequential,
4496            0,
4497            false,
4498            None,
4499        );
4500
4501        assert_eq!(engine.transformer_paths(), vec![shard_a, shard_b]);
4502
4503        fs::remove_dir_all(dir).ok();
4504    }
4505
4506    #[test]
4507    fn qwen_validate_paths_accepts_existing_files() {
4508        let dir = temp_test_dir("mold-qwen-validate-ok");
4509        let shard_a = touch(&dir, "transformer-00001-of-00002.safetensors");
4510        let shard_b = touch(&dir, "transformer-00002-of-00002.safetensors");
4511        let vae = touch(&dir, "vae.safetensors");
4512        let tokenizer = touch(&dir, "tokenizer.json");
4513        let gguf = touch(&dir, "transformer.gguf");
4514
4515        let sharded = QwenImageEngine::new(
4516            "qwen-image:bf16".to_string(),
4517            qwen_image_model_paths(
4518                dir.join("transformer.safetensors"),
4519                vec![shard_a, shard_b],
4520                vae.clone(),
4521                Some(tokenizer.clone()),
4522            ),
4523            LoadStrategy::Sequential,
4524            0,
4525            false,
4526            None,
4527        );
4528        assert_eq!(sharded.validate_paths().unwrap(), tokenizer);
4529        assert!(!sharded.detect_is_quantized());
4530
4531        let quantized = QwenImageEngine::new(
4532            "qwen-image:q4".to_string(),
4533            qwen_image_model_paths(gguf, vec![], vae, Some(dir.join("tokenizer.json"))),
4534            LoadStrategy::Sequential,
4535            0,
4536            false,
4537            None,
4538        );
4539        assert!(quantized.detect_is_quantized());
4540
4541        fs::remove_dir_all(dir).ok();
4542    }
4543
4544    #[test]
4545    fn qwen_validate_paths_requires_text_tokenizer() {
4546        let dir = temp_test_dir("mold-qwen-validate-missing");
4547        let engine = QwenImageEngine::new(
4548            "qwen-image:q4".to_string(),
4549            qwen_image_model_paths(
4550                dir.join("transformer.gguf"),
4551                vec![],
4552                dir.join("vae.safetensors"),
4553                None,
4554            ),
4555            LoadStrategy::Sequential,
4556            0,
4557            false,
4558            None,
4559        );
4560
4561        let err = engine.validate_paths().unwrap_err();
4562        assert!(err.to_string().contains("text tokenizer path required"));
4563
4564        fs::remove_dir_all(dir).ok();
4565    }
4566
4567    #[test]
4568    fn qwen_image_loads_text_tokenizer_through_shared_pool() {
4569        let dir = temp_test_dir("mold-qwen-tokenizer-pool");
4570        let tokenizer_path = dir.join("tokenizer.json");
4571        tokenizers::Tokenizer::new(BPE::default())
4572            .save(&tokenizer_path, false)
4573            .unwrap();
4574
4575        let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
4576        let pooled = shared_pool
4577            .lock()
4578            .unwrap()
4579            .load_tokenizer(&tokenizer_path)
4580            .unwrap();
4581
4582        let engine = QwenImageEngine::new(
4583            "qwen-image:q4".to_string(),
4584            qwen_image_model_paths(
4585                dir.join("transformer.gguf"),
4586                vec![],
4587                dir.join("vae.safetensors"),
4588                Some(tokenizer_path.clone()),
4589            ),
4590            LoadStrategy::Sequential,
4591            0,
4592            false,
4593            Some(shared_pool),
4594        );
4595
4596        let loaded = engine.load_text_tokenizer(&tokenizer_path).unwrap();
4597
4598        assert!(Arc::ptr_eq(&pooled, &loaded));
4599        fs::remove_dir_all(dir).ok();
4600    }
4601
4602    #[test]
4603    fn qwen_image_loads_vae_tensors_through_shared_pool() {
4604        let dir = temp_test_dir("mold-qwen-vae-pool");
4605        let vae_path = dir.join("vae.safetensors");
4606        let weight = 1.0f32.to_le_bytes();
4607        let mut tensors = HashMap::new();
4608        tensors.insert(
4609            "encoder.conv_in.weight".to_string(),
4610            TensorView::new(SafeDtype::F32, vec![1], &weight).unwrap(),
4611        );
4612        serialize_to_file(&tensors, &None, &vae_path).unwrap();
4613
4614        let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
4615        let pooled = shared_pool
4616            .lock()
4617            .unwrap()
4618            .load_safetensors_cpu_tensors(std::slice::from_ref(&vae_path))
4619            .unwrap()
4620            .unwrap();
4621
4622        let engine = QwenImageEngine::new(
4623            "qwen-image:q4".to_string(),
4624            qwen_image_model_paths(
4625                dir.join("transformer.gguf"),
4626                vec![],
4627                vae_path.clone(),
4628                Some(dir.join("tokenizer.json")),
4629            ),
4630            LoadStrategy::Sequential,
4631            0,
4632            false,
4633            Some(shared_pool),
4634        );
4635
4636        let loaded = engine.load_vae_cpu_tensors().unwrap().unwrap();
4637
4638        assert!(Arc::ptr_eq(&pooled, &loaded));
4639        fs::remove_dir_all(dir).ok();
4640    }
4641
4642    #[test]
4643    fn qwen_img2img_uses_minus_one_to_one_source_normalization() {
4644        assert_eq!(
4645            QwenImageEngine::img2img_source_normalize_range(),
4646            img_utils::NormalizeRange::MinusOneToOne
4647        );
4648    }
4649}