Skip to main content

mold_inference/qwen_image/
pipeline.rs

1//! Qwen-Image-2512 inference engine.
2//!
3//! Pipeline: Qwen2.5-VL text encoder -> QwenImageTransformer2DModel -> QwenImage VAE
4//!
5//! Architecture follows Z-Image closely (both from Alibaba/Tongyi):
6//! - Dual-stream transformer with joint attention and 3D RoPE
7//! - Flow-matching Euler discrete scheduler with dynamic shifting
8//! - Drop-and-reload for text encoder to manage VRAM
9//! - Both Eager and Sequential loading modes
10//!
11//! Key differences from Z-Image:
12//! - 60 identical dual-stream blocks (no noise_refiner/context_refiner)
13//! - Qwen2.5-VL text encoder (hidden_size=3584) instead of Qwen3 (2560)
14//! - Custom VAE with per-channel latent normalization
15//! - Official diffusers-style exponential time shift with dynamic per-image stretch
16
17use anyhow::{bail, Result};
18use candle_core::{DType, Device, IndexOp, Tensor, D};
19use candle_transformers::models::z_image::postprocess_image;
20use candle_transformers::quantized_var_builder;
21use mold_core::{fit_to_target_area, GenerateRequest, GenerateResponse, ImageData, ModelPaths};
22use std::collections::HashMap;
23use std::path::Path;
24use std::sync::{Arc, Mutex};
25use std::time::Instant;
26use tokenizers::Tokenizer;
27
28use super::quantized_transformer::QuantizedQwenImageTransformer2DModel;
29use super::sampling::{image_seq_len, QwenImageScheduler};
30use super::transformer::{QwenImageConfig, QwenImageTransformer2DModel};
31use super::vae::QwenImageVae;
32use crate::cache::{
33    clear_cache, prompt_text_key, CachedTensor, LruCache, DEFAULT_PROMPT_CACHE_CAPACITY,
34};
35use crate::device::{
36    effective_device_ref, fits_in_memory, fmt_gb, free_vram_bytes, memory_status_string,
37    preflight_memory_check, qwen2_vram_threshold, should_use_gpu, usable_free_vram_bytes,
38};
39use crate::encoders;
40use crate::engine::{rand_seed, InferenceEngine, LoadStrategy};
41use crate::engine_base::EngineBase;
42use crate::image::{build_output_metadata, encode_image};
43use crate::img_utils;
44use crate::progress::{ProgressCallback, ProgressEvent, ProgressReporter};
45use crate::upscaler::tiling::{upscale_with_tiling, TilingConfig};
46
/// Minimum free VRAM (bytes) required to place the Qwen-Image VAE on the GPU.
///
/// Budget: ~300MB of VAE weights plus ~1-2GB of decode workspace at 1024x1024.
const VAE_DECODE_VRAM_THRESHOLD: u64 = 2_500_000_000;
/// Stand-in negative prompt for the unconditional CFG branch.
///
/// A single space is used rather than `""` so the unconditional CFG path stays
/// explicit after Qwen prompt templating and token windowing.
const QWEN_EMPTY_NEGATIVE_PROMPT: &str = " ";
/// Native Qwen-Image output width in pixels.
const QWEN_NATIVE_WIDTH: usize = 1328;
/// Native Qwen-Image output height in pixels.
const QWEN_NATIVE_HEIGHT: usize = 1328;
/// VRAM headroom (bytes) associated with CFG on GGUF transformers at native size.
/// NOTE(review): call sites are outside this excerpt — confirm exact semantics.
const QWEN_GGUF_NATIVE_CFG_HEADROOM: u64 = 14_000_000_000;
/// Lower bound (bytes) for the GGUF CFG VRAM headroom.
/// NOTE(review): call sites are outside this excerpt — confirm exact semantics.
const QWEN_GGUF_MIN_CFG_HEADROOM: u64 = 3_000_000_000;
/// Latent-space tile sizes tried, largest first, by the tiled VAE decode.
const QWEN_VAE_TILE_SIZES: [u32; 3] = [64, 32, 16];
/// Target pixel area (1024 x 1024) for Qwen-Image-Edit VAE inputs.
const QWEN_IMAGE_EDIT_VAE_AREA: u32 = 1 << 20;
/// System prompt placed in the chat template of Qwen-Image-Edit prompts.
const QWEN_IMAGE_EDIT_SYSTEM_PROMPT: &str = "Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.";

/// Minimum free VRAM (bytes) to run the BF16 Qwen2.5-VL 7B text encoder on the
/// GPU: roughly 14GB of weights plus 2GB of headroom.
const QWEN2_FP16_VRAM_THRESHOLD: u64 = 16_000_000_000;
/// Extra residual VRAM (bytes) required before keeping Qwen2.5 on the GPU after
/// a prompt-cache miss. The denoise/VAE reserves already cover the known
/// workspaces; this margin absorbs allocator fragmentation and backend scratch
/// buffers.
const QWEN2_HOT_TE_RESIDENCY_HEADROOM: u64 = 1_000_000_000;
68
/// Placement override for the Qwen2.5-VL text encoder, read from the
/// `MOLD_QWEN2_TEXT_ENCODER_MODE` environment variable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Qwen2TextEncoderMode {
    /// No override: the variable is unset, not valid Unicode, or unrecognised.
    Auto,
    /// Value `gpu`.
    Gpu,
    /// Value `cpu-stage` or `cpu_stage`.
    CpuStage,
    /// Value `cpu`.
    Cpu,
}

impl Qwen2TextEncoderMode {
    /// Parse `MOLD_QWEN2_TEXT_ENCODER_MODE`.
    ///
    /// Matching is ASCII case-insensitive and ignores surrounding whitespace,
    /// so values written by shell scripts or `.env` files (which often carry a
    /// trailing newline or space) are still honoured. Anything else yields
    /// [`Qwen2TextEncoderMode::Auto`].
    fn from_env() -> Self {
        let raw = std::env::var("MOLD_QWEN2_TEXT_ENCODER_MODE").unwrap_or_default();
        match raw.trim().to_ascii_lowercase().as_str() {
            "gpu" => Self::Gpu,
            "cpu-stage" | "cpu_stage" => Self::CpuStage,
            "cpu" => Self::Cpu,
            _ => Self::Auto,
        }
    }
}
92
/// Final placement decision for the Qwen2.5-VL text encoder.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Qwen2TextEncoderPlan {
    /// Run the encoder on the GPU.
    use_gpu: bool,
    /// Stage the encoder through CPU memory.
    /// NOTE(review): exact staging semantics live outside this excerpt.
    use_cpu_staging: bool,
}
98
/// Result of choosing which Qwen2.5-VL text-encoder variant to load.
#[derive(Debug, Clone)]
struct ResolvedQwen2TextEncoder {
    /// Language-model weight files. Left empty by variant selection;
    /// presumably filled in once files are resolved.
    paths: Vec<std::path::PathBuf>,
    /// Vision-tower weight files (presumed from the name); also empty after
    /// variant selection.
    vision_paths: Vec<std::path::PathBuf>,
    /// `true` for quantized GGUF variants, `false` for the BF16 variant.
    is_gguf: bool,
    /// Variant tag such as `"bf16"` or `"q4"`.
    variant_label: String,
    /// Size of the selected variant in bytes.
    size_bytes: u64,
    /// Whether automatic placement chose the GPU for this variant.
    auto_use_gpu: bool,
}
108
/// How the text encoder will be used across a request.
/// NOTE(review): exact meaning presumed from the variant names.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Qwen2TextEncoderUsage {
    /// Loaded for encoding and then released.
    Sequential,
    /// Kept loaded alongside the other components.
    Resident,
}
114
/// What to do with the text encoder once prompt encoding has finished.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Qwen2TextEncoderPostEncodeAction {
    /// Leave the encoder resident on the GPU.
    KeepGpu,
    /// Move the encoder into host RAM for later reuse.
    ParkCpu,
    /// Release the encoder entirely.
    Drop,
}
121
/// Inputs to the post-encode residency decision for the Qwen2.5 text encoder.
#[derive(Debug, Clone, Copy)]
struct Qwen2TextEncoderResidencyInput {
    /// The encoder currently lives on the GPU.
    on_gpu: bool,
    /// The encoder was loaded from a quantized checkpoint.
    is_quantized: bool,
    /// The active backend is Metal.
    is_metal: bool,
    /// Host RAM may be used to park the encoder.
    keep_te_ram: bool,
    /// This request missed the prompt cache.
    prompt_cache_miss: bool,
    /// The denoising transformer is resident.
    transformer_resident: bool,
    /// Free VRAM, in bytes.
    free_vram_bytes: u64,
    /// VRAM that must remain free to keep the encoder hot, in bytes.
    required_vram_bytes: u64,
}
133
/// Summary statistics of a tensor, for debug diagnostics.
/// NOTE(review): presumably produced by `debug_tensor_stats` (outside this excerpt).
#[derive(Debug, Clone, Copy)]
struct QwenTensorStats {
    /// Smallest value observed.
    min: f32,
    /// Largest value observed.
    max: f32,
    /// Arithmetic mean.
    mean: f32,
    /// Number of NaN elements.
    nan_count: u64,
    /// Number of `+inf` elements.
    pos_inf_count: u64,
    /// Number of `-inf` elements.
    neg_inf_count: u64,
    /// Total number of elements.
    total: usize,
}
144
145/// Check if a Qwen-Image safetensors checkpoint stores weights in FP8 (F8_E4M3).
146/// Uses filename pattern first, then dtype probing as fallback.
147fn safetensors_is_fp8(path: &Path) -> bool {
148    // Filename-based detection
149    if path.to_str().map(|s| s.contains("fp8")).unwrap_or(false) {
150        return true;
151    }
152    // Dtype probing — try both ComfyUI and diffusers key names
153    let Ok(tensors) = (unsafe { candle_core::safetensors::MmapedSafetensors::multi(&[path]) })
154    else {
155        return false;
156    };
157    for key in ["x_embedder.weight", "img_in.weight"] {
158        if let Ok(t) = tensors.load(key, &Device::Cpu) {
159            return t.dtype() == DType::F8E4M3;
160        }
161    }
162    false
163}
164
165/// Check if text encoder safetensors contain FP8 weights.
166/// Uses filename pattern first (reliable for known ComfyUI FP8 models),
167/// then falls back to dtype probing.
168fn text_encoder_is_fp8(paths: &[std::path::PathBuf]) -> bool {
169    // Filename-based detection (ComfyUI FP8 models have "fp8" in name)
170    if paths
171        .iter()
172        .any(|p| p.to_str().map(|s| s.contains("fp8")).unwrap_or(false))
173    {
174        return true;
175    }
176    // Dtype probing fallback — try common key names
177    let Some(first) = paths.first() else {
178        return false;
179    };
180    let Ok(tensors) = (unsafe { candle_core::safetensors::MmapedSafetensors::multi(&[first]) })
181    else {
182        return false;
183    };
184    for key in [
185        "model.embed_tokens.weight",
186        "model.layers.0.self_attn.q_proj.weight",
187    ] {
188        if let Ok(t) = tensors.load(key, &Device::Cpu) {
189            return t.dtype() == DType::F8E4M3;
190        }
191    }
192    false
193}
194
195/// Loaded Qwen-Image model components, ready for inference.
196struct LoadedQwenImage {
197    /// Transformer wrapped in Option for drop-and-reload pattern.
198    transformer: Option<QwenImageTransformer>,
199    text_encoder: encoders::qwen2_text::Qwen2TextEncoder,
200    vae: QwenImageVae,
201    vae_path: std::path::PathBuf,
202    transformer_cfg: QwenImageConfig,
203    /// GPU device for transformer + denoising
204    device: Device,
205    /// Device where the VAE lives (may be CPU if VRAM is tight)
206    vae_device: Device,
207    dtype: DType,
208}
209
210#[allow(clippy::large_enum_variant)]
211enum QwenImageTransformer {
212    BF16(QwenImageTransformer2DModel),
213    Quantized(QuantizedQwenImageTransformer2DModel),
214    Offloaded(super::offload::OffloadedQwenImageTransformer),
215}
216
217#[derive(Clone)]
218struct CachedPromptConditioning {
219    hidden_states: CachedTensor,
220    valid_len: usize,
221}
222
223impl CachedPromptConditioning {
224    fn from_parts(hidden_states: &Tensor, valid_len: usize) -> Result<Self> {
225        Ok(Self {
226            hidden_states: CachedTensor::from_tensor(hidden_states)?,
227            valid_len,
228        })
229    }
230
231    fn restore(&self, device: &Device, dtype: DType) -> Result<(Tensor, Tensor)> {
232        let hidden_states = self.hidden_states.restore(device, dtype)?;
233        let mut mask = vec![0u8; hidden_states.dim(1)?];
234        for value in &mut mask[..self.valid_len] {
235            *value = 1;
236        }
237        let attention_mask = Tensor::from_vec(mask, (1, hidden_states.dim(1)?), device)?;
238        Ok((hidden_states, attention_mask))
239    }
240}
241
242fn pad_text_conditioning(
243    hidden_states: &Tensor,
244    attention_mask: &Tensor,
245    target_len: usize,
246) -> Result<(Tensor, Tensor)> {
247    let seq_len = hidden_states.dim(1)?;
248    if seq_len == target_len {
249        return Ok((hidden_states.clone(), attention_mask.clone()));
250    }
251    if seq_len > target_len {
252        bail!("cannot shrink text conditioning from {seq_len} to {target_len}");
253    }
254
255    let hidden_dim = hidden_states.dim(2)?;
256    let pad_len = target_len - seq_len;
257    let pad_hs = Tensor::zeros(
258        (hidden_states.dim(0)?, pad_len, hidden_dim),
259        hidden_states.dtype(),
260        hidden_states.device(),
261    )?;
262    let pad_mask = Tensor::zeros(
263        (attention_mask.dim(0)?, pad_len),
264        attention_mask.dtype(),
265        attention_mask.device(),
266    )?;
267
268    Ok((
269        Tensor::cat(&[hidden_states, &pad_hs], 1)?,
270        Tensor::cat(&[attention_mask, &pad_mask], 1)?,
271    ))
272}
273
274fn align_cfg_conditioning(
275    cond_hs: &Tensor,
276    cond_mask: &Tensor,
277    uncond_hs: &Tensor,
278    uncond_mask: &Tensor,
279) -> Result<((Tensor, Tensor), (Tensor, Tensor))> {
280    let target_len = cond_hs.dim(1)?.max(uncond_hs.dim(1)?);
281    let cond = pad_text_conditioning(cond_hs, cond_mask, target_len)?;
282    let uncond = pad_text_conditioning(uncond_hs, uncond_mask, target_len)?;
283    Ok((cond, uncond))
284}
285
286impl QwenImageTransformer {
287    fn supports_cfg_batching(&self) -> bool {
288        match self {
289            Self::Quantized(model) => model.supports_cfg_batching(),
290            _ => true,
291        }
292    }
293
294    fn forward(
295        &self,
296        latents: &Tensor,
297        t: &Tensor,
298        encoder_hidden_states: &Tensor,
299        encoder_attention_mask: &Tensor,
300    ) -> Result<Tensor> {
301        match self {
302            Self::BF16(model) => {
303                Ok(model.forward(latents, t, encoder_hidden_states, encoder_attention_mask)?)
304            }
305            Self::Quantized(model) => {
306                Ok(model.forward(latents, t, encoder_hidden_states, encoder_attention_mask)?)
307            }
308            Self::Offloaded(model) => {
309                model.forward(latents, t, encoder_hidden_states, encoder_attention_mask)
310            }
311        }
312    }
313
314    fn forward_packed(
315        &self,
316        packed_latents: &Tensor,
317        t: &Tensor,
318        encoder_hidden_states: &Tensor,
319        encoder_attention_mask: &Tensor,
320        img_shapes: &[(usize, usize, usize)],
321    ) -> Result<Tensor> {
322        match self {
323            Self::BF16(model) => Ok(model.forward_packed(
324                packed_latents,
325                t,
326                encoder_hidden_states,
327                encoder_attention_mask,
328                img_shapes,
329            )?),
330            Self::Quantized(model) => Ok(model.forward_packed(
331                packed_latents,
332                t,
333                encoder_hidden_states,
334                encoder_attention_mask,
335                img_shapes,
336            )?),
337            Self::Offloaded(model) => model.forward_packed(
338                packed_latents,
339                t,
340                encoder_hidden_states,
341                encoder_attention_mask,
342                img_shapes,
343            ),
344        }
345    }
346}
347
348/// Qwen-Image-2512 inference engine.
349pub struct QwenImageEngine {
350    base: EngineBase<LoadedQwenImage>,
351    prompt_cache: Mutex<LruCache<String, CachedPromptConditioning>>,
352    offload: bool,
353    /// Per-request placement override.
354    pending_placement: Option<mold_core::types::DevicePlacement>,
355    /// Per-request LoRA stack. Captured at the start of `generate()`,
356    /// cleared on exit. The transformer-load path consults this when
357    /// constructing the `VarBuilder` so the LoRA-merged weights land
358    /// before any forward pass runs.
359    pending_loras: Vec<mold_core::LoraWeight>,
360    /// Fingerprint of the LoRA stack currently baked into the loaded
361    /// transformer. Eager-mode generates compare against this to decide
362    /// whether to rebuild — an unchanged stack reuses the previously
363    /// merged weights. Currently always recomputed at load time
364    /// (same correctness-first stance as the sibling flux2 / sd3 / sdxl
365    /// / z-image early ports); the fingerprint API is in place for the
366    /// rebuild-elision follow-up.
367    #[allow(dead_code)]
368    active_lora_fingerprint: Vec<QwenImageLoraFingerprint>,
369    shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
370}
371
372/// Order-sensitive fingerprint of a single LoRA adapter (path-hash + scale).
373#[derive(Clone, PartialEq, Eq, Debug)]
374#[allow(dead_code)]
375struct QwenImageLoraFingerprint {
376    path_hash: u64,
377    scale_bits: u64,
378}
379
380impl QwenImageLoraFingerprint {
381    #[allow(dead_code)]
382    fn from_lora(lora: &mold_core::LoraWeight) -> Self {
383        Self {
384            path_hash: super::lora::lora_path_hash(&lora.path),
385            scale_bits: lora.scale.to_bits(),
386        }
387    }
388}
389
390#[allow(dead_code)]
391fn fingerprint_stack(loras: &[mold_core::LoraWeight]) -> Vec<QwenImageLoraFingerprint> {
392    loras
393        .iter()
394        .map(QwenImageLoraFingerprint::from_lora)
395        .collect()
396}
397
398/// Resolve the effective LoRA list for a request. Mirrors `flux::pipeline::
399/// effective_loras` — accepts both `lora` (singular, legacy) and `loras`
400/// (plural, current). Entries with a near-zero scale are dropped.
401fn effective_loras(req: &mold_core::GenerateRequest) -> Vec<mold_core::LoraWeight> {
402    /// Match the FLUX threshold so the user-facing semantics are
403    /// identical across families.
404    const ZERO_SCALE_EPS: f64 = 1e-8;
405
406    let raw: Vec<mold_core::LoraWeight> = if let Some(plural) = &req.loras {
407        if !plural.is_empty() {
408            plural.clone()
409        } else {
410            req.lora.iter().cloned().collect()
411        }
412    } else {
413        req.lora.iter().cloned().collect()
414    };
415
416    raw.into_iter()
417        .filter(|w| {
418            let keep = w.scale.abs() > ZERO_SCALE_EPS;
419            if !keep {
420                tracing::debug!(
421                    path = w.path.as_str(),
422                    scale = w.scale,
423                    "dropping zero-scale LoRA from effective Qwen-Image stack"
424                );
425            }
426            keep
427        })
428        .collect()
429}
430
431impl QwenImageEngine {
432    fn is_edit_family(&self) -> bool {
433        self.base.model_name.starts_with("qwen-image-edit")
434    }
435
436    fn should_preload_text_encoder(&self) -> bool {
437        !self.is_edit_family()
438    }
439
440    fn text_encoder_load_dtype(use_gpu: bool, gpu_dtype: DType) -> DType {
441        if use_gpu {
442            gpu_dtype
443        } else {
444            // Candle CPU matmul does not support BF16 for the Qwen2.5 encoder path.
445            // Keep CPU language/vision encoding in F32 and use quantized GGUF when
446            // lower host residency is needed.
447            DType::F32
448        }
449    }
450
451    fn transformer_config(&self) -> QwenImageConfig {
452        if self.is_edit_family() {
453            QwenImageConfig::qwen_image_edit_2511()
454        } else {
455            QwenImageConfig::qwen_image_2512()
456        }
457    }
458
459    fn qwen_image_edit_prompt(prompt: &str, image_count: usize) -> String {
460        let picture_prefix = (0..image_count)
461            .map(|idx| {
462                format!(
463                    "Picture {}: <|vision_start|><|image_pad|><|vision_end|>",
464                    idx + 1
465                )
466            })
467            .collect::<String>();
468        format!(
469            "<|im_start|>system\n{QWEN_IMAGE_EDIT_SYSTEM_PROMPT}<|im_end|>\n<|im_start|>user\n{picture_prefix}{prompt}<|im_end|>\n<|im_start|>assistant\n"
470        )
471    }
472
473    fn qwen_image_edit_image_dims(image: &[u8], target_area: u32) -> Result<(u32, u32)> {
474        let img = image::load_from_memory(image)?;
475        Ok(fit_to_target_area(
476            img.width().max(1),
477            img.height().max(1),
478            target_area,
479            16,
480        ))
481    }
482
483    fn pack_latents_4d(latents: &Tensor) -> Result<Tensor> {
484        let (batch, channels, height, width) = latents.dims4()?;
485        let height_blocks = height / 2;
486        let width_blocks = width / 2;
487        latents
488            .reshape((batch, channels, height_blocks, 2, width_blocks, 2))?
489            .permute((0, 2, 4, 1, 3, 5))?
490            .reshape((batch, height_blocks * width_blocks, channels * 4))
491            .map_err(Into::into)
492    }
493
494    fn unpack_latents_packed(latents: &Tensor, latent_h: usize, latent_w: usize) -> Result<Tensor> {
495        let batch = latents.dim(0)?;
496        latents
497            .reshape((batch, latent_h / 2, latent_w / 2, 16, 2, 2))?
498            .permute((0, 3, 1, 4, 2, 5))?
499            .reshape((batch, 16, latent_h, latent_w))
500            .map_err(Into::into)
501    }
502
503    fn img2img_source_normalize_range() -> img_utils::NormalizeRange {
504        img_utils::NormalizeRange::MinusOneToOne
505    }
506
507    fn is_oom_error(err: &impl std::fmt::Display) -> bool {
508        // TODO: Replace this with typed backend inspection if candle exposes
509        // one. Today the fallback ladder has to key off the backend error text.
510        let msg = err.to_string();
511        msg.contains("OUT_OF_MEMORY")
512            || msg.contains("out of memory")
513            || msg.contains("cudaErrorMemoryAllocation")
514    }
515
516    fn with_cuda_oom_cpu_fallback<T, FPrimary, FFallback, FOom>(
517        primary: FPrimary,
518        fallback: FFallback,
519        is_cuda: bool,
520        sync_device: &Device,
521        progress: &ProgressReporter,
522        oom_message: &str,
523        is_oom: FOom,
524    ) -> Result<T>
525    where
526        FPrimary: FnOnce() -> Result<T>,
527        FFallback: FnOnce() -> Result<T>,
528        FOom: Fn(&anyhow::Error) -> bool,
529    {
530        match primary() {
531            Ok(value) => Ok(value),
532            Err(err) if is_cuda && is_oom(&err) => {
533                progress.info(oom_message);
534                sync_device.synchronize()?;
535                fallback()
536            }
537            Err(err) => Err(err),
538        }
539    }
540
541    #[allow(clippy::too_many_arguments)]
542    fn with_cuda_tiled_then_cpu_fallback<T, FPrimary, FTiled, FCpu, FOom>(
543        primary: FPrimary,
544        tiled: FTiled,
545        cpu_fallback: FCpu,
546        is_cuda: bool,
547        prefer_tiled: bool,
548        sync_device: &Device,
549        progress: &ProgressReporter,
550        tiled_message: &str,
551        cpu_message: &str,
552        is_oom: FOom,
553    ) -> Result<T>
554    where
555        FPrimary: FnOnce() -> Result<T>,
556        FTiled: FnOnce() -> Result<T>,
557        FCpu: FnOnce() -> Result<T>,
558        FOom: Fn(&anyhow::Error) -> bool,
559    {
560        if is_cuda && prefer_tiled {
561            progress.info("Selecting tiled GPU VAE decode proactively");
562            match tiled() {
563                Ok(value) => return Ok(value),
564                Err(tile_err) if is_oom(&tile_err) => {
565                    progress.info(cpu_message);
566                    sync_device.synchronize()?;
567                    return cpu_fallback();
568                }
569                Err(tile_err) => return Err(tile_err),
570            }
571        }
572
573        match primary() {
574            Ok(value) => Ok(value),
575            Err(err) if is_cuda && is_oom(&err) => {
576                progress.info(tiled_message);
577                sync_device.synchronize()?;
578                match tiled() {
579                    Ok(value) => Ok(value),
580                    Err(tile_err) if is_oom(&tile_err) => {
581                        progress.info(cpu_message);
582                        sync_device.synchronize()?;
583                        cpu_fallback()
584                    }
585                    Err(tile_err) => Err(tile_err),
586                }
587            }
588            Err(err) => Err(err),
589        }
590    }
591
592    fn qwen_vae_decode_workspace_bytes(width: u32, height: u32) -> u64 {
593        let pixels = width as u64 * height as u64;
594        // Qwen's 3D causal VAE decode has a much larger transient workspace
595        // than the final RGB tensor. This factor is intentionally conservative:
596        // native 1328² requests reserve ~7.2 GB, while small 512² requests stay
597        // below the proactive tiling threshold.
598        pixels.saturating_mul(4).saturating_mul(1024)
599    }
600
601    fn should_proactively_tile_vae_decode(
602        width: u32,
603        height: u32,
604        vae_is_cuda: bool,
605        free_vram_bytes: u64,
606    ) -> bool {
607        if !vae_is_cuda || free_vram_bytes == 0 {
608            return false;
609        }
610        let native_pixels = (QWEN_NATIVE_WIDTH * QWEN_NATIVE_HEIGHT) as u64;
611        let pixels = width as u64 * height as u64;
612        if pixels < native_pixels.saturating_mul(3) / 4 {
613            return false;
614        }
615        let required = VAE_DECODE_VRAM_THRESHOLD
616            .saturating_add(Self::qwen_vae_decode_workspace_bytes(width, height));
617        free_vram_bytes < required
618    }
619
620    fn qwen2_text_encoder_post_encode_action(
621        input: Qwen2TextEncoderResidencyInput,
622    ) -> Qwen2TextEncoderPostEncodeAction {
623        if !input.on_gpu {
624            return Qwen2TextEncoderPostEncodeAction::Drop;
625        }
626        if input.prompt_cache_miss
627            && input.transformer_resident
628            && !input.is_metal
629            && input.free_vram_bytes >= input.required_vram_bytes
630        {
631            return Qwen2TextEncoderPostEncodeAction::KeepGpu;
632        }
633        if input.keep_te_ram && !input.is_metal && !input.is_quantized {
634            return Qwen2TextEncoderPostEncodeAction::ParkCpu;
635        }
636        Qwen2TextEncoderPostEncodeAction::Drop
637    }
638
639    fn qwen2_hot_text_encoder_required_vram(
640        width: u32,
641        height: u32,
642        cfg_batch: u32,
643        dtype: DType,
644    ) -> u64 {
645        crate::device::activation_bytes(
646            width,
647            height,
648            cfg_batch,
649            crate::device::dtype_bytes(dtype),
650            crate::device::ActivationFamily::QwenImageDit,
651        )
652        .saturating_add(VAE_DECODE_VRAM_THRESHOLD)
653        .saturating_add(Self::qwen_vae_decode_workspace_bytes(width, height))
654        .saturating_add(QWEN2_HOT_TE_RESIDENCY_HEADROOM)
655    }
656
657    fn decode_vae_tiled(
658        latents: &Tensor,
659        vae: &QwenImageVae,
660        vae_device: &Device,
661        progress: &ProgressReporter,
662    ) -> Result<Tensor> {
663        for tile_size in QWEN_VAE_TILE_SIZES {
664            let overlap = (tile_size / 4).max(4);
665            progress.info(&format!(
666                "Retrying VAE decode with tiled GPU decode (tile {} overlap {})",
667                tile_size, overlap
668            ));
669            let config = TilingConfig {
670                tile_size,
671                overlap,
672                min_tile_size: 16,
673            };
674            let forward = |tile: &Tensor| {
675                let tile = tile.to_device(vae_device)?.to_dtype(DType::F32)?;
676                vae.decode(&tile).map_err(Into::into)
677            };
678            // `upscale_with_tiling` is reused here because Qwen-Image VAE decode
679            // is guaranteed to return 3-channel RGB. If a future VAE family
680            // changes that contract, this call site needs a tiler that handles
681            // arbitrary output channel counts.
682            match upscale_with_tiling(latents, &forward, 8, &config, &Device::Cpu, progress) {
683                Ok(image) => return Ok(image),
684                Err(e) if vae_device.is_cuda() && Self::is_oom_error(&e) => {
685                    if let Err(sync_err) = vae_device.synchronize() {
686                        tracing::warn!(
687                            "failed to synchronize CUDA device after tiled VAE OOM: {sync_err}"
688                        );
689                    }
690                }
691                Err(e) => return Err(e),
692            }
693        }
694
695        bail!("tiled VAE decode still ran out of memory")
696    }
697
698    fn decode_vae_with_fallback<F>(
699        latents: &Tensor,
700        vae: &QwenImageVae,
701        vae_device: &Device,
702        sync_device: &Device,
703        progress: &ProgressReporter,
704        prefer_tiled: bool,
705        load_cpu_vae: F,
706    ) -> Result<Tensor>
707    where
708        F: FnOnce() -> Result<QwenImageVae>,
709    {
710        let decode_latents = latents.to_device(vae_device)?.to_dtype(DType::F32)?;
711        Self::debug_tensor_stats("latents_pre_vae", &decode_latents);
712        Self::with_cuda_tiled_then_cpu_fallback(
713            || vae.decode(&decode_latents).map_err(Into::into),
714            || Self::decode_vae_tiled(latents, vae, vae_device, progress),
715            || {
716                let cpu_vae = load_cpu_vae()?;
717                let cpu_latents = latents.to_device(&Device::Cpu)?.to_dtype(DType::F32)?;
718                cpu_vae.decode(&cpu_latents).map_err(Into::into)
719            },
720            vae_device.is_cuda(),
721            prefer_tiled,
722            sync_device,
723            progress,
724            "VAE decode OOM on GPU — retrying with tiled GPU decode",
725            "Tiled GPU VAE decode OOM — retrying on CPU",
726            Self::is_oom_error,
727        )
728    }
729
730    /// Encode a source image through the VAE with GPU→CPU OOM fallback.
731    #[allow(clippy::too_many_arguments)]
732    fn encode_vae_with_fallback(
733        source_bytes: &[u8],
734        width: u32,
735        height: u32,
736        vae: &QwenImageVae,
737        vae_device: &Device,
738        sync_device: &Device,
739        progress: &ProgressReporter,
740        load_cpu_vae: impl FnOnce() -> Result<QwenImageVae>,
741    ) -> Result<Tensor> {
742        progress.stage_start("Encoding source image (VAE)");
743        let encode_start = Instant::now();
744
745        // Qwen-Image VAE expects [-1, 1] normalized pixels
746        let source_tensor = img_utils::decode_source_image(
747            source_bytes,
748            width,
749            height,
750            Self::img2img_source_normalize_range(),
751            vae_device,
752            DType::F32,
753        )?;
754
755        let result = Self::with_cuda_oom_cpu_fallback(
756            || vae.encode(&source_tensor).map_err(Into::into),
757            || {
758                let cpu_vae = load_cpu_vae()?;
759                let cpu_source = img_utils::decode_source_image(
760                    source_bytes,
761                    width,
762                    height,
763                    Self::img2img_source_normalize_range(),
764                    &Device::Cpu,
765                    DType::F32,
766                )?;
767                cpu_vae.encode(&cpu_source).map_err(Into::into)
768            },
769            vae_device.is_cuda(),
770            sync_device,
771            progress,
772            "VAE encode OOM on GPU — retrying on CPU",
773            Self::is_oom_error,
774        );
775
776        progress.stage_done("Encoding source image (VAE)", encode_start.elapsed());
777        result
778    }
779
780    fn choose_text_encoder_source(
781        preference: Option<&str>,
782        is_cuda: bool,
783        is_metal: bool,
784        free_vram: u64,
785        bf16_size_bytes: u64,
786        _usage: Qwen2TextEncoderUsage,
787    ) -> Result<ResolvedQwen2TextEncoder> {
788        match preference {
789            Some(tag) if tag != "auto" && tag != "bf16" => {
790                let variant = mold_core::manifest::find_qwen2_vl_variant(tag).ok_or_else(|| {
791                    anyhow::anyhow!(
792                        "unknown Qwen2.5-VL variant '{}'. Valid: bf16, auto, q8, q6, q5, q4, q3, q2",
793                        tag
794                    )
795                })?;
796                Ok(ResolvedQwen2TextEncoder {
797                    paths: vec![],
798                    vision_paths: vec![],
799                    is_gguf: true,
800                    variant_label: variant.tag.to_string(),
801                    size_bytes: variant.size_bytes,
802                    auto_use_gpu: should_use_gpu(
803                        is_cuda,
804                        is_metal,
805                        free_vram,
806                        qwen2_vram_threshold(variant.size_bytes),
807                    ),
808                })
809            }
810            Some("bf16") => Ok(ResolvedQwen2TextEncoder {
811                paths: vec![],
812                vision_paths: vec![],
813                is_gguf: false,
814                variant_label: "bf16".to_string(),
815                size_bytes: bf16_size_bytes,
816                auto_use_gpu: should_use_gpu(
817                    is_cuda,
818                    is_metal,
819                    free_vram,
820                    QWEN2_FP16_VRAM_THRESHOLD,
821                ),
822            }),
823            _ if is_metal => {
824                for tag in ["q6", "q4"] {
825                    let variant = mold_core::manifest::find_qwen2_vl_variant(tag)
826                        .expect("known Metal auto qwen2 variant missing");
827                    if fits_in_memory(
828                        is_cuda,
829                        is_metal,
830                        free_vram,
831                        qwen2_vram_threshold(variant.size_bytes),
832                    ) {
833                        return Ok(ResolvedQwen2TextEncoder {
834                            paths: vec![],
835                            vision_paths: vec![],
836                            is_gguf: true,
837                            variant_label: variant.tag.to_string(),
838                            size_bytes: variant.size_bytes,
839                            auto_use_gpu: true,
840                        });
841                    }
842                }
843                let fallback = mold_core::manifest::find_qwen2_vl_variant("q4")
844                    .expect("known Metal fallback qwen2 variant missing");
845                Ok(ResolvedQwen2TextEncoder {
846                    paths: vec![],
847                    vision_paths: vec![],
848                    is_gguf: true,
849                    variant_label: fallback.tag.to_string(),
850                    size_bytes: fallback.size_bytes,
851                    auto_use_gpu: true,
852                })
853            }
854            _ => {
855                let bf16_on_gpu =
856                    should_use_gpu(is_cuda, is_metal, free_vram, QWEN2_FP16_VRAM_THRESHOLD);
857                if bf16_on_gpu {
858                    return Ok(ResolvedQwen2TextEncoder {
859                        paths: vec![],
860                        vision_paths: vec![],
861                        is_gguf: false,
862                        variant_label: "bf16".to_string(),
863                        size_bytes: bf16_size_bytes,
864                        auto_use_gpu: true,
865                    });
866                }
867
868                if is_cuda {
869                    let fallback_tag = "q4";
870                    let fallback = mold_core::manifest::find_qwen2_vl_variant(fallback_tag)
871                        .expect("known CUDA fallback qwen2 variant missing");
872                    return Ok(ResolvedQwen2TextEncoder {
873                        paths: vec![],
874                        vision_paths: vec![],
875                        is_gguf: true,
876                        variant_label: fallback.tag.to_string(),
877                        size_bytes: fallback.size_bytes,
878                        auto_use_gpu: fits_in_memory(
879                            is_cuda,
880                            is_metal,
881                            free_vram,
882                            qwen2_vram_threshold(fallback.size_bytes),
883                        ),
884                    });
885                }
886
887                Ok(ResolvedQwen2TextEncoder {
888                    paths: vec![],
889                    vision_paths: vec![],
890                    is_gguf: false,
891                    variant_label: "bf16".to_string(),
892                    size_bytes: bf16_size_bytes,
893                    auto_use_gpu: false,
894                })
895            }
896        }
897    }
898
899    fn tensor_stats(tensor: &Tensor) -> Result<QwenTensorStats> {
900        let t = tensor.to_dtype(DType::F32)?;
901        let values = t.flatten_all()?.to_vec1::<f32>()?;
902        let mut min = f32::INFINITY;
903        let mut max = f32::NEG_INFINITY;
904        let mut sum = 0.0f64;
905        let mut finite_count = 0usize;
906        let mut nan_count = 0u64;
907        let mut pos_inf_count = 0u64;
908        let mut neg_inf_count = 0u64;
909        for value in &values {
910            if value.is_nan() {
911                nan_count += 1;
912            } else if *value == f32::INFINITY {
913                pos_inf_count += 1;
914            } else if *value == f32::NEG_INFINITY {
915                neg_inf_count += 1;
916            } else {
917                min = min.min(*value);
918                max = max.max(*value);
919                sum += *value as f64;
920                finite_count += 1;
921            }
922        }
923        let mean = if finite_count == 0 {
924            f32::NAN
925        } else {
926            (sum / finite_count as f64) as f32
927        };
928        if finite_count == 0 {
929            min = f32::NAN;
930            max = f32::NAN;
931        }
932        Ok(QwenTensorStats {
933            min,
934            max,
935            mean,
936            nan_count,
937            pos_inf_count,
938            neg_inf_count,
939            total: values.len(),
940        })
941    }
942
943    fn format_tensor_stats(name: &str, stats: QwenTensorStats) -> String {
944        format!(
945            "[qwen-debug] {name}: min={:.4} max={:.4} mean={:.4} NaN={}/{} ({:.1}%) +Inf={} -Inf={}",
946            stats.min,
947            stats.max,
948            stats.mean,
949            stats.nan_count,
950            stats.total,
951            stats.nan_count as f64 / stats.total.max(1) as f64 * 100.0,
952            stats.pos_inf_count,
953            stats.neg_inf_count
954        )
955    }
956
957    fn near_black_image_stats(stats: QwenTensorStats) -> bool {
958        if stats.nan_count > 0
959            || stats.pos_inf_count > 0
960            || stats.neg_inf_count > 0
961            || !stats.min.is_finite()
962            || !stats.max.is_finite()
963            || !stats.mean.is_finite()
964        {
965            return false;
966        }
967        let scale = if stats.max <= 1.0 { 1.0 } else { 255.0 };
968        stats.max <= 0.02 * scale && stats.mean <= 0.01 * scale
969    }
970
971    fn validate_qwen_tensor_boundary(name: &str, tensor: &Tensor) -> Result<QwenTensorStats> {
972        let stats = Self::tensor_stats(tensor)?;
973        if stats.nan_count > 0
974            || stats.pos_inf_count > 0
975            || stats.neg_inf_count > 0
976            || !stats.min.is_finite()
977            || !stats.max.is_finite()
978            || !stats.mean.is_finite()
979        {
980            bail!(
981                "Qwen diagnostic boundary '{name}' contains non-finite values: {}",
982                Self::format_tensor_stats(name, stats)
983            );
984        }
985        Ok(stats)
986    }
987
988    fn debug_tensor_stats(name: &str, tensor: &Tensor) {
989        if std::env::var_os("MOLD_QWEN_DEBUG").is_none() {
990            return;
991        }
992        match Self::tensor_stats(tensor) {
993            Ok(stats) => eprintln!("{}", Self::format_tensor_stats(name, stats)),
994            Err(err) => eprintln!("[qwen-debug] {name}: <failed: {err}>"),
995        }
996    }
997
998    pub fn new(
999        model_name: String,
1000        paths: ModelPaths,
1001        load_strategy: LoadStrategy,
1002        gpu_ordinal: usize,
1003        offload: bool,
1004        shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
1005    ) -> Self {
1006        Self {
1007            base: EngineBase::new(model_name, paths, load_strategy, gpu_ordinal),
1008            prompt_cache: Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY)),
1009            offload,
1010            pending_placement: None,
1011            pending_loras: Vec::new(),
1012            active_lora_fingerprint: Vec::new(),
1013            shared_pool,
1014        }
1015    }
1016
1017    fn load_text_tokenizer(&self, tokenizer_path: &Path) -> Result<Arc<Tokenizer>> {
1018        if let Some(shared_pool) = &self.shared_pool {
1019            return shared_pool.lock().unwrap().load_tokenizer(tokenizer_path);
1020        }
1021        Tokenizer::from_file(tokenizer_path)
1022            .map(Arc::new)
1023            .map_err(|e| anyhow::anyhow!("failed to load Qwen2.5 tokenizer: {e}"))
1024    }
1025
1026    fn encode_prompt_cached(
1027        progress: &ProgressReporter,
1028        prompt_cache: &Mutex<LruCache<String, CachedPromptConditioning>>,
1029        text_encoder: &mut encoders::qwen2_text::Qwen2TextEncoder,
1030        prompt: &str,
1031        device: &Device,
1032        dtype: DType,
1033    ) -> Result<(Tensor, Tensor)> {
1034        let cache_key = prompt_text_key(prompt);
1035        if let Some(cached) = prompt_cache
1036            .lock()
1037            .expect("cache poisoned")
1038            .get_cloned(&cache_key)
1039        {
1040            progress.cache_hit("prompt conditioning");
1041            return cached.restore(device, dtype);
1042        }
1043
1044        progress.stage_start("Encoding prompt (Qwen2.5)");
1045        let encode_start = Instant::now();
1046        let (hidden_states, _attention_mask, valid_len) =
1047            text_encoder.encode(prompt, device, dtype)?;
1048        progress.stage_done("Encoding prompt (Qwen2.5)", encode_start.elapsed());
1049
1050        prompt_cache.lock().expect("cache poisoned").insert(
1051            cache_key,
1052            CachedPromptConditioning::from_parts(&hidden_states, valid_len)?,
1053        );
1054
1055        let mut mask = vec![0u8; hidden_states.dim(1)?];
1056        for value in &mut mask[..valid_len] {
1057            *value = 1;
1058        }
1059        let attention_mask = Tensor::from_vec(mask, (1, hidden_states.dim(1)?), device)?;
1060        Ok((hidden_states, attention_mask))
1061    }
1062
1063    fn spill_conditioning_to_cpu(
1064        hidden_states: Tensor,
1065        attention_mask: Tensor,
1066    ) -> Result<(Tensor, Tensor)> {
1067        Ok((
1068            hidden_states
1069                .to_device(&Device::Cpu)?
1070                .to_dtype(DType::F32)?,
1071            attention_mask.to_device(&Device::Cpu)?,
1072        ))
1073    }
1074
1075    fn maybe_spill_conditioning(
1076        use_cpu_staging: bool,
1077        hidden_states: Tensor,
1078        attention_mask: Tensor,
1079    ) -> Result<(Tensor, Tensor)> {
1080        if use_cpu_staging {
1081            Self::spill_conditioning_to_cpu(hidden_states, attention_mask)
1082        } else {
1083            Ok((hidden_states, attention_mask))
1084        }
1085    }
1086
1087    /// Resolve transformer shard paths.
1088    fn transformer_paths(&self) -> Vec<std::path::PathBuf> {
1089        if !self.base.paths.transformer_shards.is_empty() {
1090            self.base.paths.transformer_shards.clone()
1091        } else {
1092            vec![self.base.paths.transformer.clone()]
1093        }
1094    }
1095
1096    fn detect_is_quantized(&self) -> bool {
1097        self.base
1098            .paths
1099            .transformer
1100            .extension()
1101            .and_then(|e| e.to_str())
1102            .map(|e| e.eq_ignore_ascii_case("gguf"))
1103            .unwrap_or(false)
1104    }
1105
1106    /// Validate required paths exist.
1107    fn validate_paths(&self) -> Result<std::path::PathBuf> {
1108        let text_tokenizer_path =
1109            self.base.paths.text_tokenizer.as_ref().ok_or_else(|| {
1110                anyhow::anyhow!("text tokenizer path required for Qwen-Image models")
1111            })?;
1112        if !text_tokenizer_path.exists() {
1113            bail!(
1114                "text tokenizer file not found: {}",
1115                text_tokenizer_path.display()
1116            );
1117        }
1118
1119        let xformer_paths = self.transformer_paths();
1120        for path in &xformer_paths {
1121            if !path.exists() {
1122                bail!("transformer file not found: {}", path.display());
1123            }
1124        }
1125        if !self.base.paths.vae.exists() {
1126            bail!("VAE file not found: {}", self.base.paths.vae.display());
1127        }
1128
1129        Ok(text_tokenizer_path.clone())
1130    }
1131
1132    fn quantized_cuda_cfg_headroom(width: usize, height: usize) -> u64 {
1133        let native_pixels = (QWEN_NATIVE_WIDTH * QWEN_NATIVE_HEIGHT) as f64;
1134        let pixels = (width.max(1) * height.max(1)) as f64;
1135        let scaled =
1136            (QWEN_GGUF_NATIVE_CFG_HEADROOM as f64 * (pixels / native_pixels)).round() as u64;
1137        scaled.max(QWEN_GGUF_MIN_CFG_HEADROOM)
1138    }
1139
1140    fn should_split_cfg_quantized_cuda(
1141        transformer_size: u64,
1142        free_vram: u64,
1143        width: usize,
1144        height: usize,
1145    ) -> bool {
1146        if free_vram == 0 {
1147            // If VRAM probing fails, bias toward the safer split-CFG path
1148            // instead of assuming batched CFG will fit.
1149            return true;
1150        }
1151        let estimated_peak =
1152            transformer_size.saturating_add(Self::quantized_cuda_cfg_headroom(width, height));
1153        estimated_peak > free_vram
1154    }
1155
    /// Load the Qwen-Image transformer from disk.
    ///
    /// The backend is chosen from the checkpoint format and the memory budget:
    ///
    /// * GGUF checkpoints (see `detect_is_quantized`) become
    ///   `QwenImageTransformer::Quantized`. On CUDA, low-memory split-CFG mode
    ///   is used when `self.offload` is set, or when the transformer file size
    ///   plus resolution-scaled CFG headroom exceeds usable free VRAM.
    /// * Safetensors checkpoints (BF16, or FP8 as detected from the first
    ///   shard) become `QwenImageTransformer::Offloaded` when `self.offload` is
    ///   set or `should_offload` reports that weights plus the activation
    ///   budget do not fit. Otherwise they become `QwenImageTransformer::BF16`.
    ///
    /// LoRAs queued in `self.pending_loras` are merged at load time on the GGUF
    /// path and on the in-memory safetensors path. `width`/`height` are the
    /// target resolution; they feed only the memory estimates and log messages.
    ///
    /// # Errors
    /// Fails if weights or LoRA adapters cannot be loaded, if model
    /// construction fails, or if LoRAs are requested together with the
    /// block-offload path, which does not support them yet.
    fn load_transformer(
        &self,
        device: &Device,
        dtype: DType,
        cfg: &QwenImageConfig,
        width: usize,
        height: usize,
    ) -> Result<QwenImageTransformer> {
        let active_loras = &self.pending_loras;
        let has_lora = !active_loras.is_empty();
        if self.detect_is_quantized() {
            // A failed metadata lookup counts as 0 bytes in the estimate below.
            let transformer_size = std::fs::metadata(&self.base.paths.transformer)
                .map(|m| m.len())
                .unwrap_or(0);
            // Reserve-adjusted reading: split-CFG is a budget decision.
            // A failed probe yields 0, which should_split_cfg_quantized_cuda
            // treats as "split".
            let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
            let split_cfg_for_memory = device.is_cuda()
                && (self.offload
                    || Self::should_split_cfg_quantized_cuda(
                        transformer_size,
                        free,
                        width,
                        height,
                    ));
            if self.offload && device.is_cuda() {
                self.base.progress.info(
                    "Quantized Qwen CUDA offload requested — using low-memory split-CFG mode until GGUF block offload lands",
                );
            } else if split_cfg_for_memory {
                let estimated_peak = transformer_size
                    .saturating_add(Self::quantized_cuda_cfg_headroom(width, height));
                self.base.progress.info(&format!(
                    "Using low-memory quantized Qwen CUDA path (est. peak {}, {} free at {}x{})",
                    fmt_gb(estimated_peak),
                    fmt_gb(free),
                    width,
                    height,
                ));
            }
            let vb = if has_lora {
                // Pair each loaded adapter with its requested scale and a hash
                // of its source path before building the LoRA-aware GGUF builder.
                let adapters = super::lora::load_lora_adapters(active_loras, &self.base.progress)?;
                let specs: Vec<super::lora::QwenImageLoraSpec<'_>> = adapters
                    .iter()
                    .zip(active_loras.iter())
                    .map(|(adapter, w)| super::lora::QwenImageLoraSpec {
                        adapter: adapter.as_ref(),
                        scale: w.scale,
                        path_hash: super::lora::lora_path_hash(&w.path),
                    })
                    .collect();
                super::lora::gguf_lora_var_builder(
                    &self.base.paths.transformer,
                    &specs,
                    device,
                    &self.base.progress,
                    None,
                )?
            } else {
                quantized_var_builder::VarBuilder::from_gguf(&self.base.paths.transformer, device)?
            };
            // NOTE(review): the last argument is `!split_cfg_for_memory`;
            // presumably it enables batched CFG — confirm against
            // QuantizedQwenImageTransformer2DModel::new.
            Ok(QwenImageTransformer::Quantized(
                QuantizedQwenImageTransformer2DModel::new(cfg, vb, device, !split_cfg_for_memory)?,
            ))
        } else {
            let xformer_paths = self.transformer_paths();
            // FP8 detection inspects only the first shard; this assumes every
            // shard shares that format.
            let is_fp8 = xformer_paths
                .first()
                .map(|p| safetensors_is_fp8(p))
                .unwrap_or(false);

            // FP8 weights stay as F8E4M3 in VRAM (~19.5GB, 1 byte/param).
            // Per-layer dequant to BF16 during forward adds ~113MB transient.
            // BF16 weights are 2 bytes/param (~40GB).
            let mem_size: u64 = xformer_paths
                .iter()
                .filter_map(|p| std::fs::metadata(p).ok())
                .map(|m| m.len())
                .sum();
            // Reserve-adjusted reading: should_offload budgets against this.
            let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
            // Qwen-Image runs CFG by default; activation budget scales with
            // resolution to replace the previous fixed 3 GB heuristic.
            // NOTE(review): the `2` presumably is the CFG batch
            // (conditional + unconditional) — confirm against activation_bytes.
            let activation_budget = crate::device::activation_bytes(
                width as u32,
                height as u32,
                2,
                crate::device::dtype_bytes(dtype),
                crate::device::ActivationFamily::QwenImageDit,
            );
            let use_offload =
                self.offload || crate::device::should_offload(mem_size, free, activation_budget);

            if is_fp8 {
                self.base
                    .progress
                    .info("Detected FP8 safetensors — loading with scale dequantization");
            }

            if use_offload {
                if has_lora {
                    bail!(
                        "Qwen-Image LoRA support is not yet wired through the block-offload \
                         transformer path. Disable offload (drop --offload / unset MOLD_OFFLOAD), \
                         or pick a checkpoint that fits without offload, to use LoRAs."
                    );
                }
                // Create TWO VarBuilders: GPU for blocks that fit, CPU for overflow.
                let (gpu_vb, cpu_vb) = if is_fp8 {
                    let gpu = crate::weight_loader::load_fp8_safetensors(
                        &xformer_paths,
                        device,
                        "Qwen-Image transformer (offload, GPU)",
                        &self.base.progress,
                    )?;
                    let cpu = crate::weight_loader::load_fp8_safetensors(
                        &xformer_paths,
                        &Device::Cpu,
                        "Qwen-Image transformer (offload, CPU)",
                        &self.base.progress,
                    )?;
                    (gpu, cpu)
                } else {
                    let gpu = crate::weight_loader::load_safetensors_with_progress(
                        &xformer_paths,
                        dtype,
                        device,
                        "Qwen-Image transformer (offload, GPU)",
                        &self.base.progress,
                    )?;
                    // NOTE(review): the CPU half is mmapped as BF16 regardless
                    // of `dtype`; confirm the offload path expects this when
                    // `dtype` is not BF16.
                    // SAFETY: memory-mapping requires that the shard files are
                    // not modified or truncated while the mapping is alive.
                    let cpu = unsafe {
                        candle_nn::VarBuilder::from_mmaped_safetensors(
                            &xformer_paths
                                .iter()
                                .map(|p| p.as_path())
                                .collect::<Vec<_>>(),
                            DType::BF16,
                            &Device::Cpu,
                        )?
                    };
                    (gpu, cpu)
                };
                Ok(QwenImageTransformer::Offloaded(
                    super::offload::OffloadedQwenImageTransformer::load(
                        gpu_vb,
                        cpu_vb,
                        cfg,
                        device,
                        self.base.gpu_ordinal,
                        &self.base.progress,
                    )?,
                ))
            } else {
                // In-memory path: LoRAs (if any) are merged by a wrapping
                // backend; otherwise weights load directly (FP8 or `dtype`).
                let xformer_vb = if has_lora {
                    self.build_bf16_lora_var_builder(
                        &xformer_paths,
                        dtype,
                        device,
                        is_fp8,
                        active_loras,
                    )?
                } else if is_fp8 {
                    crate::weight_loader::load_fp8_safetensors(
                        &xformer_paths,
                        device,
                        "Qwen-Image transformer",
                        &self.base.progress,
                    )?
                } else {
                    crate::weight_loader::load_safetensors_with_progress(
                        &xformer_paths,
                        dtype,
                        device,
                        "Qwen-Image transformer",
                        &self.base.progress,
                    )?
                };
                Ok(QwenImageTransformer::BF16(
                    QwenImageTransformer2DModel::new(cfg, xformer_vb)?,
                ))
            }
        }
    }
1339
1340    /// Construct a `VarBuilder` for the BF16/FP8 in-memory path with a
1341    /// LoRA-merging `SimpleBackend` wrapping the underlying mmap (or
1342    /// `NativeFp8Backend`). Each `vb.get()` call delivers a tensor with
1343    /// `W' = W + scale·(B @ A)` already merged in.
1344    fn build_bf16_lora_var_builder<'a>(
1345        &self,
1346        xformer_paths: &[std::path::PathBuf],
1347        dtype: DType,
1348        device: &Device,
1349        is_fp8: bool,
1350        loras: &[mold_core::LoraWeight],
1351    ) -> Result<candle_nn::VarBuilder<'a>> {
1352        let adapters = super::lora::load_lora_adapters(loras, &self.base.progress)?;
1353        let specs: Vec<super::lora::QwenImageLoraSpec<'_>> = adapters
1354            .iter()
1355            .zip(loras.iter())
1356            .map(|(adapter, w)| super::lora::QwenImageLoraSpec {
1357                adapter: adapter.as_ref(),
1358                scale: w.scale,
1359                path_hash: super::lora::lora_path_hash(&w.path),
1360            })
1361            .collect();
1362
1363        let path_refs: Vec<&std::path::Path> = xformer_paths.iter().map(|p| p.as_path()).collect();
1364        let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::multi(&path_refs)? };
1365        let inner: Box<dyn candle_nn::var_builder::SimpleBackend> = if is_fp8 {
1366            // FP8 path needs the `NativeFp8Backend` so F8E4M3 weights
1367            // stay F8E4M3 in VRAM; the LoRA wrapper merges deltas in
1368            // F32 and the per-layer dequant in `QwenLinear::Fp8::forward`
1369            // sees pre-merged weights as expected.
1370            self.base
1371                .progress
1372                .info("Detected FP8 safetensors — loading with LoRA-merging wrapper");
1373            Box::new(crate::weight_loader::NativeFp8Backend::from_mmap(tensors))
1374        } else {
1375            // candle's `MmapedSafetensors` implements `SimpleBackend`
1376            // directly; use it as the inner layer of the LoRA wrapper.
1377            Box::new(tensors)
1378        };
1379
1380        let wrapped =
1381            super::lora::wrap_backend_with_lora(inner, &specs, &self.base.progress, None)?;
1382
1383        let target_dtype = if is_fp8 { DType::BF16 } else { dtype };
1384        Ok(candle_nn::VarBuilder::from_backend(
1385            wrapped,
1386            target_dtype,
1387            device.clone(),
1388        ))
1389    }
1390
1391    /// Load VAE from disk.
1392    fn load_vae(&self, device: &Device, dtype: DType) -> Result<QwenImageVae> {
1393        let vb = self.load_vae_var_builder(device, dtype)?;
1394        Ok(QwenImageVae::from_var_builder(vb, device, dtype)?)
1395    }
1396
1397    fn load_vae_cpu_tensors(&self) -> Result<Option<Arc<HashMap<String, Tensor>>>> {
1398        let Some(shared_pool) = &self.shared_pool else {
1399            return Ok(None);
1400        };
1401        shared_pool
1402            .lock()
1403            .unwrap()
1404            .load_safetensors_cpu_tensors(std::slice::from_ref(&self.base.paths.vae))
1405    }
1406
1407    fn load_vae_var_builder<'a>(
1408        &self,
1409        device: &Device,
1410        dtype: DType,
1411    ) -> Result<candle_nn::VarBuilder<'a>> {
1412        if let Some(tensors) = self.load_vae_cpu_tensors()? {
1413            return Ok(encoders::park::varbuilder_from_parked(
1414                tensors.as_ref(),
1415                dtype,
1416                device,
1417            ));
1418        }
1419
1420        crate::weight_loader::load_safetensors_with_progress(
1421            std::slice::from_ref(&self.base.paths.vae),
1422            dtype,
1423            device,
1424            "Qwen-Image VAE",
1425            &self.base.progress,
1426        )
1427    }
1428
1429    /// Load text encoder from disk.
1430    ///
1431    /// FP8 text encoders are loaded on GPU with BF16 dtype — candle's CUDA cast
1432    /// kernel handles F8E4M3→BF16 conversion during tensor loading.
1433    fn resolve_text_encoder_source(
1434        &self,
1435        gpu_device: &Device,
1436        free_vram: u64,
1437        usage: Qwen2TextEncoderUsage,
1438    ) -> Result<ResolvedQwen2TextEncoder> {
1439        let preference = std::env::var("MOLD_QWEN2_VARIANT").ok();
1440        self.resolve_text_encoder_source_with_preference(
1441            gpu_device,
1442            free_vram,
1443            usage,
1444            preference.as_deref(),
1445        )
1446    }
1447
1448    fn resolve_text_encoder_source_with_preference(
1449        &self,
1450        gpu_device: &Device,
1451        free_vram: u64,
1452        usage: Qwen2TextEncoderUsage,
1453        preference: Option<&str>,
1454    ) -> Result<ResolvedQwen2TextEncoder> {
1455        let is_cuda = gpu_device.is_cuda();
1456        let is_metal = gpu_device.is_metal();
1457        let bf16_size_bytes = self
1458            .base
1459            .paths
1460            .text_encoder_files
1461            .iter()
1462            .filter_map(|p| std::fs::metadata(p).ok())
1463            .map(|m| m.len())
1464            .sum();
1465        if self.is_edit_family() {
1466            let mut resolved = Self::choose_text_encoder_source(
1467                preference,
1468                is_cuda,
1469                is_metal,
1470                free_vram,
1471                bf16_size_bytes,
1472                Qwen2TextEncoderUsage::Resident,
1473            )?;
1474            resolved.vision_paths = self.base.paths.text_encoder_files.clone();
1475            if resolved.is_gguf {
1476                let variant = mold_core::manifest::find_qwen2_vl_variant(&resolved.variant_label)
1477                    .ok_or_else(|| {
1478                    anyhow::anyhow!("unknown Qwen2.5-VL variant '{}'", resolved.variant_label)
1479                })?;
1480                resolved.paths = vec![
1481                    crate::encoders::variant_resolution::resolve_qwen2_vl_gguf_path(
1482                        &self.base.progress,
1483                        variant,
1484                    )?,
1485                ];
1486            } else {
1487                resolved.paths = self.base.paths.text_encoder_files.clone();
1488            }
1489            return Ok(resolved);
1490        }
1491        let mut resolved = Self::choose_text_encoder_source(
1492            preference,
1493            is_cuda,
1494            is_metal,
1495            free_vram,
1496            bf16_size_bytes,
1497            usage,
1498        )?;
1499
1500        if resolved.is_gguf {
1501            let variant = mold_core::manifest::find_qwen2_vl_variant(&resolved.variant_label)
1502                .ok_or_else(|| {
1503                    anyhow::anyhow!("unknown Qwen2.5-VL variant '{}'", resolved.variant_label)
1504                })?;
1505            resolved.paths = vec![
1506                crate::encoders::variant_resolution::resolve_qwen2_vl_gguf_path(
1507                    &self.base.progress,
1508                    variant,
1509                )?,
1510            ];
1511        } else {
1512            resolved.paths = self.base.paths.text_encoder_files.clone();
1513        }
1514        resolved.vision_paths = vec![];
1515
1516        match preference {
1517            Some(tag) if tag != "auto" && tag != "bf16" => self.base.progress.info(&format!(
1518                "Using quantized Qwen2.5-VL {} ({}) on {} (explicit)",
1519                resolved.variant_label,
1520                fmt_gb(resolved.size_bytes),
1521                if resolved.auto_use_gpu { "GPU" } else { "CPU" },
1522            )),
1523            Some("bf16") => {}
1524            _ if is_metal && resolved.is_gguf && resolved.variant_label == "q6" => self
1525                .base
1526                .progress
1527                .info(&format!(
1528                    "Metal auto mode selected quantized Qwen2.5-VL {} ({}) for lower memory pressure",
1529                    resolved.variant_label,
1530                    fmt_gb(resolved.size_bytes),
1531                )),
1532            _ if is_metal && resolved.is_gguf => self.base.progress.info(&format!(
1533                "Metal auto mode forcing quantized Qwen2.5-VL {} ({}) to avoid BF16 memory pressure",
1534                resolved.variant_label,
1535                fmt_gb(resolved.size_bytes),
1536            )),
1537            _ if is_cuda && resolved.is_gguf && resolved.auto_use_gpu => self.base.progress.info(
1538                &format!(
1539                    "CUDA auto mode selected quantized Qwen2.5-VL {} ({}) on GPU",
1540                    resolved.variant_label,
1541                    fmt_gb(resolved.size_bytes),
1542                ),
1543            ),
1544            _ if is_cuda && resolved.is_gguf => self.base.progress.info(&format!(
1545                "CUDA auto mode selected quantized Qwen2.5-VL {} ({}) on CPU to avoid large BF16 host residency",
1546                resolved.variant_label,
1547                fmt_gb(resolved.size_bytes),
1548            )),
1549            _ => {}
1550        }
1551
1552        Ok(resolved)
1553    }
1554
1555    fn can_keep_transformer_hot_for_vae(loaded: &LoadedQwenImage) -> bool {
1556        Self::qwen_transformer_can_stay_hot_for_vae(
1557            loaded.device.is_cuda(),
1558            loaded.vae_device.is_cuda(),
1559            matches!(
1560                loaded.transformer.as_ref(),
1561                Some(QwenImageTransformer::Quantized(_))
1562            ),
1563        )
1564    }
1565
1566    fn qwen_transformer_can_stay_hot_for_vae(
1567        transformer_is_cuda: bool,
1568        vae_is_cuda: bool,
1569        transformer_is_quantized: bool,
1570    ) -> bool {
1571        transformer_is_cuda && vae_is_cuda && transformer_is_quantized
1572    }
1573
1574    fn decode_vae_gpu_only(
1575        latents: &Tensor,
1576        vae: &QwenImageVae,
1577        vae_device: &Device,
1578        sync_device: &Device,
1579        progress: &ProgressReporter,
1580        prefer_tiled: bool,
1581    ) -> Result<Tensor> {
1582        if vae_device.is_cuda() && prefer_tiled {
1583            progress.info("Selecting tiled GPU VAE decode proactively");
1584            return Self::decode_vae_tiled(latents, vae, vae_device, progress);
1585        }
1586
1587        let decode_latents = latents.to_device(vae_device)?.to_dtype(DType::F32)?;
1588        match vae.decode(&decode_latents) {
1589            Ok(image) => Ok(image),
1590            Err(e) if vae_device.is_cuda() && Self::is_oom_error(&e) => {
1591                progress.info(
1592                    "Resident-transformer VAE decode OOM on GPU — retrying with tiled GPU decode before dropping transformer",
1593                );
1594                sync_device.synchronize()?;
1595                Self::decode_vae_tiled(latents, vae, vae_device, progress)
1596            }
1597            Err(e) => Err(e.into()),
1598        }
1599    }
1600
1601    fn load_text_encoder(
1602        &self,
1603        resolved: &ResolvedQwen2TextEncoder,
1604        tokenizer_path: &std::path::PathBuf,
1605        tokenizer: Arc<Tokenizer>,
1606        device: &Device,
1607        dtype: DType,
1608        preload_weights: bool,
1609    ) -> Result<encoders::qwen2_text::Qwen2TextEncoder> {
1610        if resolved.is_gguf {
1611            if preload_weights {
1612                encoders::qwen2_text::Qwen2TextEncoder::load_gguf_with_tokenizer(
1613                    &resolved.paths[0],
1614                    tokenizer_path,
1615                    Some(tokenizer),
1616                    device,
1617                    dtype,
1618                    &resolved.vision_paths,
1619                    &self.base.progress,
1620                )
1621            } else {
1622                encoders::qwen2_text::Qwen2TextEncoder::prepare_gguf_with_tokenizer(
1623                    &resolved.paths[0],
1624                    tokenizer_path,
1625                    Some(tokenizer),
1626                    device,
1627                    dtype,
1628                    &resolved.vision_paths,
1629                )
1630            }
1631        } else {
1632            let is_fp8 = text_encoder_is_fp8(&resolved.paths);
1633            if is_fp8 {
1634                self.base
1635                    .progress
1636                    .info("Detected FP8 text encoder — loading as BF16 on GPU");
1637            }
1638            if preload_weights {
1639                encoders::qwen2_text::Qwen2TextEncoder::load_bf16_with_tokenizer(
1640                    &resolved.paths,
1641                    tokenizer_path,
1642                    Some(tokenizer),
1643                    device,
1644                    dtype,
1645                    self.is_edit_family(),
1646                    &self.base.progress,
1647                )
1648            } else {
1649                encoders::qwen2_text::Qwen2TextEncoder::prepare_bf16_with_tokenizer(
1650                    &resolved.paths,
1651                    tokenizer_path,
1652                    Some(tokenizer),
1653                    device,
1654                    dtype,
1655                    self.is_edit_family(),
1656                )
1657            }
1658        }
1659    }
1660
1661    /// Resolve text encoder device placement and optional CPU staging.
1662    fn resolve_text_encoder_plan(
1663        &self,
1664        gpu_device: &Device,
1665        resolved: &ResolvedQwen2TextEncoder,
1666        free_vram: u64,
1667    ) -> (Qwen2TextEncoderPlan, String) {
1668        let is_cuda = gpu_device.is_cuda();
1669        let is_metal = gpu_device.is_metal();
1670        let plan = Self::qwen2_text_encoder_plan_for_mode(
1671            Qwen2TextEncoderMode::from_env(),
1672            is_cuda,
1673            is_metal,
1674            resolved,
1675        );
1676        let label = if plan.use_gpu { "GPU" } else { "CPU" };
1677        if plan.use_cpu_staging {
1678            self.base
1679                .progress
1680                .info("Qwen2.5 text encoder on GPU with CPU staging after encoding");
1681        } else if !plan.use_gpu {
1682            if resolved.is_gguf {
1683                self.base.progress.info(&format!(
1684                    "Qwen2.5 text encoder on CPU ({} variant {}, {} free)",
1685                    resolved.variant_label,
1686                    fmt_gb(resolved.size_bytes),
1687                    fmt_gb(free_vram),
1688                ));
1689            } else if is_metal || is_cuda {
1690                self.base.progress.info(&format!(
1691                    "Qwen2.5 text encoder on CPU ({} free < {} threshold)",
1692                    fmt_gb(free_vram),
1693                    fmt_gb(QWEN2_FP16_VRAM_THRESHOLD),
1694                ));
1695            }
1696        }
1697        (plan, label.to_string())
1698    }
1699
1700    fn qwen2_text_encoder_plan_for_mode(
1701        mode: Qwen2TextEncoderMode,
1702        is_cuda: bool,
1703        is_metal: bool,
1704        resolved: &ResolvedQwen2TextEncoder,
1705    ) -> Qwen2TextEncoderPlan {
1706        match mode {
1707            Qwen2TextEncoderMode::Gpu => Qwen2TextEncoderPlan {
1708                use_gpu: is_cuda || is_metal,
1709                use_cpu_staging: false,
1710            },
1711            Qwen2TextEncoderMode::CpuStage => Qwen2TextEncoderPlan {
1712                use_gpu: is_cuda || is_metal,
1713                use_cpu_staging: is_cuda || is_metal,
1714            },
1715            Qwen2TextEncoderMode::Cpu => Qwen2TextEncoderPlan {
1716                use_gpu: false,
1717                use_cpu_staging: false,
1718            },
1719            Qwen2TextEncoderMode::Auto => Qwen2TextEncoderPlan {
1720                use_gpu: resolved.auto_use_gpu,
1721                use_cpu_staging: is_metal && resolved.auto_use_gpu && !resolved.is_gguf,
1722            },
1723        }
1724    }
1725
    /// Load all model components (Eager mode).
    ///
    /// Components are loaded in this order:
    /// 1. The transformer, on the resolved transformer device, using the
    ///    `QWEN_NATIVE_WIDTH` × `QWEN_NATIVE_HEIGHT` dimensions.
    /// 2. The VAE, always in F32, on the GPU or CPU depending on free VRAM and any
    ///    placement override.
    /// 3. The Qwen2.5 text encoder, either fully loaded or only prepared for staged
    ///    loading, as decided by `should_preload_text_encoder`.
    ///
    /// Returns `Ok(())` without loading anything when components are already loaded,
    /// or when the engine uses `LoadStrategy::Sequential` (loading then happens per
    /// request in `generate_sequential`).
    ///
    /// On error, `self.base.loaded` remains `None` — all components are assembled into
    /// local variables and only stored in `self.base.loaded` on success, so partial loads
    /// cannot leave the engine in an inconsistent state.
    pub fn load(&mut self) -> Result<()> {
        // Idempotent: a second call after a successful load is a no-op.
        if self.base.loaded.is_some() {
            return Ok(());
        }

        // Sequential mode defers loading to generate_sequential()
        if self.base.load_strategy == LoadStrategy::Sequential {
            return Ok(());
        }

        tracing::info!(model = %self.base.model_name, "loading Qwen-Image model components...");

        let text_tokenizer_path = self.validate_paths()?;
        // Apply any pending placement override for the transformer. The closure appears
        // to be the fallback used when no override pins a device — TODO confirm in
        // `resolve_device`.
        let transformer_ref = effective_device_ref(
            self.pending_placement.as_ref(),
            |adv| Some(adv.transformer),
            false,
        );
        let device = crate::device::resolve_device(Some(transformer_ref), || {
            crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)
        })?;
        let transformer_cfg = self.transformer_config();
        let transformer_is_quantized = self.detect_is_quantized();
        // FP8 safetensors are loaded as BF16 via CPU (candle CUDA kernel bug
        // prevents direct F8E4M3→BF16 on GPU; CPU cast works fine). All paths
        // use BF16 as runtime dtype since the model trains and computes in BF16.
        let dtype = crate::engine::gpu_dtype(&device);

        // Load transformer
        let xformer_paths = self.transformer_paths();
        let xformer_label = if transformer_is_quantized {
            "Loading Qwen-Image transformer (quantized)".to_string()
        } else {
            format!(
                "Loading Qwen-Image transformer ({} shards)",
                xformer_paths.len()
            )
        };
        self.base.progress.stage_start(&xformer_label);
        let xformer_start = Instant::now();
        // Eager mode does not know the request size yet, so it loads with the native
        // dimensions; `reload_transformer` takes explicit width/height.
        let transformer = self.load_transformer(
            &device,
            dtype,
            &transformer_cfg,
            QWEN_NATIVE_WIDTH,
            QWEN_NATIVE_HEIGHT,
        )?;
        self.base
            .progress
            .stage_done(&xformer_label, xformer_start.elapsed());
        tracing::info!("Qwen-Image transformer loaded");

        // Decide device placement for VAE and text encoder.
        // Log raw, budget against the reserve-adjusted reading.
        let free_raw = free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
        let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
        let is_cuda = device.is_cuda();
        let is_metal = device.is_metal();
        // A zero reading means the query failed or is unsupported; skip the log then.
        if free_raw > 0 {
            self.base.progress.info(&format!(
                "Free VRAM after transformer: {}",
                fmt_gb(free_raw)
            ));
        }

        // Auto VAE placement: GPU only if enough VRAM remains for decoding.
        let vae_on_gpu = should_use_gpu(is_cuda, is_metal, free, VAE_DECODE_VRAM_THRESHOLD);
        let vae_ref =
            effective_device_ref(self.pending_placement.as_ref(), |adv| Some(adv.vae), false);
        let vae_device = crate::device::resolve_device(Some(vae_ref), || {
            Ok(if vae_on_gpu {
                device.clone()
            } else {
                Device::Cpu
            })
        })?;
        // Re-derive from the device actually chosen, since an override may differ.
        let vae_on_gpu = !vae_device.is_cpu();
        // Always decode in F32 — BF16 convolutions accumulate quantization noise across
        // the 4 upsampling blocks, producing visible grain. Matches diffusers' force_upcast.
        let vae_dtype = DType::F32;
        let vae_device_label = if vae_on_gpu { "GPU" } else { "CPU" };

        // Load VAE
        let vae_label = format!("Loading Qwen-Image VAE ({}, F32)", vae_device_label);
        self.base.progress.stage_start(&vae_label);
        let vae_start = Instant::now();
        let vae = self.load_vae(&vae_device, vae_dtype)?;
        self.base
            .progress
            .stage_done(&vae_label, vae_start.elapsed());

        // Load text encoder
        // NOTE(review): `free` was sampled before the VAE was loaded. If the VAE went to
        // the GPU, this budget overstates the VRAM actually left for the text encoder.
        // Consider re-reading `usable_free_vram_bytes` here.
        let resolved_text_encoder =
            self.resolve_text_encoder_source(&device, free, Qwen2TextEncoderUsage::Resident)?;
        let (te_plan, te_auto_device_label) =
            self.resolve_text_encoder_plan(&device, &resolved_text_encoder, free);
        let qwen_ref = effective_device_ref(self.pending_placement.as_ref(), |adv| adv.qwen, true);
        let auto_te_device = if te_plan.use_gpu {
            device.clone()
        } else {
            Device::Cpu
        };
        let te_device =
            crate::device::resolve_device(Some(qwen_ref), || Ok(auto_te_device.clone()))?;
        let te_use_gpu = !te_device.is_cpu();
        // Keep the auto plan's label only when an override did not flip GPU/CPU.
        let te_device_label: String = if te_use_gpu == te_plan.use_gpu {
            te_auto_device_label
        } else if te_use_gpu {
            "GPU".into()
        } else {
            "CPU".into()
        };
        let te_dtype = Self::text_encoder_load_dtype(te_use_gpu, dtype);

        // When not preloading, the encoder is only prepared now (see the
        // "prepared for staged loading" log below).
        let preload_text_encoder = self.should_preload_text_encoder();
        let te_label = if resolved_text_encoder.is_gguf {
            if preload_text_encoder {
                format!(
                    "Loading Qwen2.5 text encoder ({} GGUF, {})",
                    resolved_text_encoder.variant_label, te_device_label
                )
            } else {
                format!(
                    "Preparing Qwen2.5 text encoder ({} GGUF, {})",
                    resolved_text_encoder.variant_label, te_device_label
                )
            }
        } else if preload_text_encoder {
            format!(
                "Loading Qwen2.5 text encoder ({} shards, {})",
                resolved_text_encoder.paths.len(),
                te_device_label,
            )
        } else {
            format!(
                "Preparing Qwen2.5 text encoder ({} shards, {})",
                resolved_text_encoder.paths.len(),
                te_device_label,
            )
        };
        self.base.progress.stage_start(&te_label);
        let te_start = Instant::now();
        let text_tokenizer = self.load_text_tokenizer(&text_tokenizer_path)?;
        let text_encoder = self.load_text_encoder(
            &resolved_text_encoder,
            &text_tokenizer_path,
            text_tokenizer,
            &te_device,
            te_dtype,
            preload_text_encoder,
        )?;
        self.base.progress.stage_done(&te_label, te_start.elapsed());
        if preload_text_encoder {
            tracing::info!(device = %te_device_label, "Qwen2.5 text encoder loaded");
        } else {
            tracing::info!(device = %te_device_label, "Qwen2.5 text encoder prepared for staged loading");
        }

        // Commit only after every component succeeded (see the method docs).
        self.base.loaded = Some(LoadedQwenImage {
            transformer: Some(transformer),
            text_encoder,
            vae,
            vae_path: self.base.paths.vae.clone(),
            transformer_cfg,
            device,
            vae_device,
            dtype,
        });

        tracing::info!(model = %self.base.model_name, "all Qwen-Image components loaded");
        Ok(())
    }
1902
1903    /// Reload the transformer from disk.
1904    fn reload_transformer(
1905        &self,
1906        loaded: &mut LoadedQwenImage,
1907        width: usize,
1908        height: usize,
1909    ) -> Result<()> {
1910        let transformer = self.load_transformer(
1911            &loaded.device,
1912            loaded.dtype,
1913            &loaded.transformer_cfg,
1914            width,
1915            height,
1916        )?;
1917        loaded.transformer = Some(transformer);
1918        Ok(())
1919    }
1920
1921    /// Generate using sequential loading strategy (load-use-drop each component).
1922    fn generate_sequential(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1923        let text_tokenizer_path = self.validate_paths()?;
1924        let transformer_cfg = self.transformer_config();
1925
1926        let transformer_ref = effective_device_ref(
1927            self.pending_placement.as_ref(),
1928            |adv| Some(adv.transformer),
1929            false,
1930        );
1931        let device = crate::device::resolve_device(Some(transformer_ref), || {
1932            crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)
1933        })?;
1934        let dtype = crate::engine::gpu_dtype(&device);
1935        let transformer_is_quantized = self.detect_is_quantized();
1936
1937        let start = Instant::now();
1938        let seed = req.seed.unwrap_or_else(rand_seed);
1939
1940        let width = req.width as usize;
1941        let height = req.height as usize;
1942        // Reserve-adjusted reading: text-encoder source / placement is a
1943        // budget decision.
1944        let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
1945        let resolved_text_encoder =
1946            self.resolve_text_encoder_source(&device, free, Qwen2TextEncoderUsage::Sequential)?;
1947        let (plan, _device_label) =
1948            self.resolve_text_encoder_plan(&device, &resolved_text_encoder, free);
1949        let use_cpu_staging = plan.use_cpu_staging;
1950
1951        tracing::info!(
1952            prompt = %req.prompt,
1953            seed, width, height,
1954            steps = req.steps,
1955            "starting sequential Qwen-Image generation"
1956        );
1957
1958        self.base
1959            .progress
1960            .info("Using sequential loading (load-use-drop) to minimize peak memory");
1961
1962        // --- Phase 1: Text encoding (check cache first to skip encoder load) ---
1963        let use_cfg = req.guidance > 1.0;
1964        let prompt_key = prompt_text_key(&req.prompt);
1965        let uncond_key = prompt_text_key(QWEN_EMPTY_NEGATIVE_PROMPT);
1966        let (prompt_cached, uncond_cached) = {
1967            let mut cache = self.prompt_cache.lock().expect("cache poisoned");
1968            let prompt_cached = cache.get_cloned(&prompt_key);
1969            let uncond_cached = if use_cfg {
1970                cache.get_cloned(&uncond_key)
1971            } else {
1972                None
1973            };
1974            (prompt_cached, uncond_cached)
1975        };
1976        let both_cached = prompt_cached.is_some() && (!use_cfg || uncond_cached.is_some());
1977
1978        let (mut encoder_hidden_states, mut encoder_attention_mask, mut uncond_hs, mut uncond_mask) =
1979            if both_cached {
1980                self.base.progress.cache_hit("prompt conditioning");
1981                let cached = prompt_cached.unwrap();
1982                let restore_device = if use_cpu_staging {
1983                    &Device::Cpu
1984                } else {
1985                    &device
1986                };
1987                let restore_dtype = if use_cpu_staging { DType::F32 } else { dtype };
1988                let (hs, mask) = cached.restore(restore_device, restore_dtype)?;
1989                let (u_hs, u_mask) = if use_cfg {
1990                    let ucached = uncond_cached.unwrap();
1991                    let (u_hs, u_mask) = ucached.restore(restore_device, restore_dtype)?;
1992                    (Some(u_hs), Some(u_mask))
1993                } else {
1994                    (None, None)
1995                };
1996                (hs, mask, u_hs, u_mask)
1997            } else {
1998                let (te_plan, te_auto_device_label) =
1999                    self.resolve_text_encoder_plan(&device, &resolved_text_encoder, free);
2000                let qwen_ref =
2001                    effective_device_ref(self.pending_placement.as_ref(), |adv| adv.qwen, true);
2002                let auto_te_device = if te_plan.use_gpu {
2003                    device.clone()
2004                } else {
2005                    Device::Cpu
2006                };
2007                let te_device =
2008                    crate::device::resolve_device(Some(qwen_ref), || Ok(auto_te_device.clone()))?;
2009                let te_use_gpu = !te_device.is_cpu();
2010                let te_device_label: String = if te_use_gpu == te_plan.use_gpu {
2011                    te_auto_device_label
2012                } else if te_use_gpu {
2013                    "GPU".into()
2014                } else {
2015                    "CPU".into()
2016                };
2017                let te_dtype = Self::text_encoder_load_dtype(te_use_gpu, dtype);
2018
2019                let te_label = if resolved_text_encoder.is_gguf {
2020                    format!(
2021                        "Loading Qwen2.5 text encoder ({} GGUF, {})",
2022                        resolved_text_encoder.variant_label, te_device_label
2023                    )
2024                } else {
2025                    format!(
2026                        "Loading Qwen2.5 text encoder ({} shards, {})",
2027                        resolved_text_encoder.paths.len(),
2028                        te_device_label,
2029                    )
2030                };
2031                if te_plan.use_cpu_staging && device.is_metal() && !resolved_text_encoder.is_gguf {
2032                    self.base.progress.info(
2033                        "Skipping hard preflight for Qwen2.5 text encoder on Metal; sequential mode spills prompt conditioning to CPU after encoding",
2034                    );
2035                } else {
2036                    let te_activation_budget = crate::device::activation_bytes(
2037                        req.width,
2038                        req.height,
2039                        1,
2040                        crate::device::dtype_bytes(te_dtype),
2041                        crate::device::ActivationFamily::SmallTransformer,
2042                    );
2043                    preflight_memory_check(
2044                        "Qwen2.5 text encoder",
2045                        resolved_text_encoder.size_bytes,
2046                        te_activation_budget,
2047                    )?;
2048                }
2049
2050                if let Some(status) = memory_status_string() {
2051                    self.base.progress.info(&status);
2052                }
2053
2054                self.base.progress.stage_start(&te_label);
2055                let te_start = Instant::now();
2056                let text_tokenizer = self.load_text_tokenizer(&text_tokenizer_path)?;
2057                let mut text_encoder = self.load_text_encoder(
2058                    &resolved_text_encoder,
2059                    &text_tokenizer_path,
2060                    text_tokenizer,
2061                    &te_device,
2062                    te_dtype,
2063                    true,
2064                )?;
2065                self.base.progress.stage_done(&te_label, te_start.elapsed());
2066
2067                let (hs, mask) = Self::encode_prompt_cached(
2068                    &self.base.progress,
2069                    &self.prompt_cache,
2070                    &mut text_encoder,
2071                    &req.prompt,
2072                    &device,
2073                    dtype,
2074                )?;
2075                let (hs, mask) = Self::maybe_spill_conditioning(use_cpu_staging, hs, mask)?;
2076
2077                let (u_hs, u_mask) = if use_cfg {
2078                    let (hs, mask) = Self::encode_prompt_cached(
2079                        &self.base.progress,
2080                        &self.prompt_cache,
2081                        &mut text_encoder,
2082                        QWEN_EMPTY_NEGATIVE_PROMPT,
2083                        &device,
2084                        dtype,
2085                    )?;
2086                    let (hs, mask) = Self::maybe_spill_conditioning(use_cpu_staging, hs, mask)?;
2087                    (Some(hs), Some(mask))
2088                } else {
2089                    (None, None)
2090                };
2091
2092                drop(text_encoder);
2093                // Force the backend to release allocator state before transformer load.
2094                device.synchronize()?;
2095                if let Some(status) = crate::device::memory_status_string() {
2096                    if use_cpu_staging {
2097                        self.base.progress.info(&format!(
2098                            "Freed Qwen2.5 text encoder and spilled prompt conditioning to CPU — {status}"
2099                        ));
2100                    } else {
2101                        self.base
2102                            .progress
2103                            .info(&format!("Freed Qwen2.5 text encoder — {status}"));
2104                    }
2105                } else {
2106                    if use_cpu_staging {
2107                        self.base.progress.info(
2108                            "Freed Qwen2.5 text encoder and spilled prompt conditioning to CPU",
2109                        );
2110                    } else {
2111                        self.base.progress.info("Freed Qwen2.5 text encoder");
2112                    }
2113                }
2114
2115                (hs, mask, u_hs, u_mask)
2116            };
2117
2118        if use_cfg {
2119            let ((cond_hs, cond_mask), (neg_hs, neg_mask)) = align_cfg_conditioning(
2120                &encoder_hidden_states,
2121                &encoder_attention_mask,
2122                uncond_hs.as_ref().expect("unconditional prompt missing"),
2123                uncond_mask.as_ref().expect("unconditional mask missing"),
2124            )?;
2125            encoder_hidden_states = cond_hs;
2126            encoder_attention_mask = cond_mask;
2127            uncond_hs = Some(neg_hs);
2128            uncond_mask = Some(neg_mask);
2129        }
2130
2131        // --- Phase 2: Load transformer and denoise ---
2132        let xformer_paths = self.transformer_paths();
2133        let xformer_size: u64 = xformer_paths
2134            .iter()
2135            .filter_map(|p| std::fs::metadata(p).ok())
2136            .map(|m| m.len())
2137            .sum();
2138        let xformer_activation_budget = crate::device::activation_bytes(
2139            req.width,
2140            req.height,
2141            if req.guidance > 1.0 { 2 } else { 1 },
2142            crate::device::dtype_bytes(dtype),
2143            crate::device::ActivationFamily::QwenImageDit,
2144        );
2145        preflight_memory_check(
2146            "Qwen-Image transformer",
2147            xformer_size,
2148            xformer_activation_budget,
2149        )?;
2150
2151        if let Some(status) = memory_status_string() {
2152            self.base.progress.info(&status);
2153        }
2154
2155        let xformer_label = if transformer_is_quantized {
2156            "Loading Qwen-Image transformer (quantized)".to_string()
2157        } else {
2158            format!(
2159                "Loading Qwen-Image transformer ({} shards)",
2160                xformer_paths.len()
2161            )
2162        };
2163        self.base.progress.stage_start(&xformer_label);
2164        let xformer_start = Instant::now();
2165        let transformer = self.load_transformer(&device, dtype, &transformer_cfg, width, height)?;
2166        self.base
2167            .progress
2168            .stage_done(&xformer_label, xformer_start.elapsed());
2169
2170        if use_cpu_staging {
2171            encoder_hidden_states = encoder_hidden_states.to_device(&device)?.to_dtype(dtype)?;
2172            encoder_attention_mask = encoder_attention_mask.to_device(&device)?;
2173            if let Some(hs) = uncond_hs.take() {
2174                uncond_hs = Some(hs.to_device(&device)?.to_dtype(dtype)?);
2175            }
2176            if let Some(mask) = uncond_mask.take() {
2177                uncond_mask = Some(mask.to_device(&device)?);
2178            }
2179            if let Some(status) = memory_status_string() {
2180                self.base.progress.info(&format!(
2181                    "Restored prompt conditioning to GPU for denoising — {status}"
2182                ));
2183            } else {
2184                self.base
2185                    .progress
2186                    .info("Restored prompt conditioning to GPU for denoising");
2187            }
2188        }
2189
2190        // Calculate latent dimensions: image_size / 8 (VAE downsample factor)
2191        let vae_downsample = 8;
2192        let latent_h = height / vae_downsample;
2193        let latent_w = width / vae_downsample;
2194        let is_img2img = req.source_image.is_some();
2195
2196        // For img2img, load VAE early to encode source image before transformer
2197        let (prepared_img2img_latents, inpaint_ctx) = if let Some(ref source_bytes) =
2198            req.source_image
2199        {
2200            // Reserve-adjusted reading drives the encode-device decision.
2201            let free_for_encode = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
2202            let encode_on_gpu = should_use_gpu(
2203                device.is_cuda(),
2204                device.is_metal(),
2205                free_for_encode,
2206                VAE_DECODE_VRAM_THRESHOLD,
2207            );
2208            let encode_device = if encode_on_gpu {
2209                device.clone()
2210            } else {
2211                Device::Cpu
2212            };
2213            let encode_label = if encode_on_gpu { "GPU" } else { "CPU" };
2214
2215            let vae_label = format!("Loading Qwen-Image VAE ({}, F32) for encode", encode_label);
2216            self.base.progress.stage_start(&vae_label);
2217            let vae_start = Instant::now();
2218            let encode_vae = self.load_vae(&encode_device, DType::F32)?;
2219            self.base
2220                .progress
2221                .stage_done(&vae_label, vae_start.elapsed());
2222
2223            let encoded = Self::encode_vae_with_fallback(
2224                source_bytes,
2225                req.width,
2226                req.height,
2227                &encode_vae,
2228                &encode_device,
2229                &device,
2230                &self.base.progress,
2231                || self.load_vae(&Device::Cpu, DType::F32),
2232            )?;
2233            let encoded = encoded.to_device(&device)?.to_dtype(dtype)?;
2234            let start_sigma = QwenImageScheduler::new_img2img(
2235                req.steps as usize,
2236                image_seq_len(latent_h, latent_w, transformer_cfg.patch_size),
2237                req.strength,
2238            )
2239            .0
2240            .initial_sigma();
2241            let prepared = crate::img2img::prepare_flow_match_img2img(
2242                &encoded,
2243                seed,
2244                &[1, 16, latent_h, latent_w],
2245                start_sigma,
2246                req.mask_image.as_deref(),
2247                latent_h,
2248                latent_w,
2249                &device,
2250                dtype,
2251            )?;
2252
2253            // Drop early VAE to free memory before transformer load
2254            drop(encode_vae);
2255            device.synchronize()?;
2256
2257            tracing::info!(
2258                strength = req.strength,
2259                "img2img: encoded source image to latents"
2260            );
2261
2262            (Some(prepared.initial_latents), prepared.inpaint_ctx)
2263        } else {
2264            (None, None)
2265        };
2266
2267        let image_seq_len = image_seq_len(latent_h, latent_w, transformer_cfg.patch_size);
2268        let (mut scheduler, num_steps) = if is_img2img {
2269            QwenImageScheduler::new_img2img(req.steps as usize, image_seq_len, req.strength)
2270        } else {
2271            let sched = QwenImageScheduler::new(req.steps as usize, image_seq_len);
2272            let n = sched.num_steps();
2273            (sched, n)
2274        };
2275
2276        // Build initial latents
2277        let mut latents = if let Some(initial) = &prepared_img2img_latents {
2278            initial.clone()
2279        } else {
2280            let noise =
2281                crate::engine::seeded_randn(seed, &[1, 16, latent_h, latent_w], &device, dtype)?;
2282            (noise * scheduler.initial_sigma())?
2283        };
2284
2285        let denoise_label = format!("Denoising ({} steps)", num_steps);
2286        self.base.progress.stage_start(&denoise_label);
2287        let denoise_start = Instant::now();
2288
2289        if std::env::var_os("MOLD_QWEN_DEBUG").is_some() {
2290            eprintln!(
2291                "[qwen-debug] cfg={} guidance={:.1} image_seq_len={} sigmas[0]={:.4} sigmas[last]={:.4} img2img={}",
2292                use_cfg,
2293                req.guidance,
2294                image_seq_len,
2295                scheduler.sigmas[0],
2296                scheduler.sigmas[scheduler.sigmas.len() - 1],
2297                is_img2img,
2298            );
2299        }
2300
2301        let use_batched_cfg = use_cfg && transformer.supports_cfg_batching();
2302        if use_cfg && !use_batched_cfg {
2303            self.base.progress.info(
2304                "Low-memory quantized Qwen CUDA path detected — disabling CFG batching to reduce peak CUDA memory",
2305            );
2306        }
2307
2308        // Pre-batch CFG inputs when the selected transformer path can handle the
2309        // extra batch dimension without exceeding peak memory.
2310        let (batched_hs, batched_mask) = if use_batched_cfg {
2311            let hs = Tensor::cat(&[&encoder_hidden_states, uncond_hs.as_ref().unwrap()], 0)?;
2312            let mask = Tensor::cat(&[&encoder_attention_mask, uncond_mask.as_ref().unwrap()], 0)?;
2313            (hs, mask)
2314        } else {
2315            (
2316                encoder_hidden_states.clone(),
2317                encoder_attention_mask.clone(),
2318            )
2319        };
2320
2321        for step in 0..num_steps {
2322            let step_start = Instant::now();
2323            let t = scheduler.current_timestep();
2324            let noise_pred = if use_cfg {
2325                let (cond_pred, uncond_pred) = if use_batched_cfg {
2326                    let t_tensor =
2327                        Tensor::from_vec(vec![t as f32; 2], (2,), &device)?.to_dtype(dtype)?;
2328                    let batched_latents = Tensor::cat(&[&latents, &latents], 0)?;
2329                    let batched_pred = transformer.forward(
2330                        &batched_latents,
2331                        &t_tensor,
2332                        &batched_hs,
2333                        &batched_mask,
2334                    )?;
2335                    (batched_pred.narrow(0, 0, 1)?, batched_pred.narrow(0, 1, 1)?)
2336                } else {
2337                    let t_tensor =
2338                        Tensor::from_vec(vec![t as f32], (1,), &device)?.to_dtype(dtype)?;
2339                    (
2340                        transformer.forward(
2341                            &latents,
2342                            &t_tensor,
2343                            &encoder_hidden_states,
2344                            &encoder_attention_mask,
2345                        )?,
2346                        transformer.forward(
2347                            &latents,
2348                            &t_tensor,
2349                            uncond_hs.as_ref().unwrap(),
2350                            uncond_mask.as_ref().unwrap(),
2351                        )?,
2352                    )
2353                };
2354                if step == 0 {
2355                    Self::debug_tensor_stats("cond_pred[0]", &cond_pred);
2356                    Self::debug_tensor_stats("uncond_pred[0]", &uncond_pred);
2357                }
2358                // CFG in F32 to avoid BF16 cancellation error, then norm rescale
2359                // to match diffusers' Qwen-Image pipeline.
2360                let cond_f32 = cond_pred.to_dtype(DType::F32)?;
2361                let uncond_f32 = uncond_pred.to_dtype(DType::F32)?;
2362                let comb = (&uncond_f32 + ((&cond_f32 - &uncond_f32)? * req.guidance)?)?;
2363                let cond_norm = cond_f32.sqr()?.sum_keepdim(1)?.sqrt()?;
2364                let comb_norm = comb.sqr()?.sum_keepdim(1)?.sqrt()?.clamp(1e-8, f64::MAX)?;
2365                let rescaled = comb.broadcast_mul(&(cond_norm / comb_norm)?)?;
2366                rescaled.to_dtype(dtype)?
2367            } else {
2368                let t_tensor = Tensor::from_vec(vec![t as f32], (1,), &device)?.to_dtype(dtype)?;
2369                transformer.forward(
2370                    &latents,
2371                    &t_tensor,
2372                    &encoder_hidden_states,
2373                    &encoder_attention_mask,
2374                )?
2375            };
2376            if step == 0 || step == num_steps / 2 || step == num_steps - 1 {
2377                Self::debug_tensor_stats(&format!("noise_pred[{step}]"), &noise_pred);
2378                Self::debug_tensor_stats(&format!("latents[{step}]"), &latents);
2379            }
2380            if step == 0 {
2381                Self::validate_qwen_tensor_boundary("noise_pred[0]", &noise_pred)?;
2382            }
2383            latents = scheduler.step(&noise_pred, &latents)?;
2384            if step == num_steps - 1 {
2385                Self::validate_qwen_tensor_boundary("latents_final", &latents)?;
2386            }
2387
2388            // Inpainting: blend preserved regions back at current noise level
2389            if let Some(ref ctx) = inpaint_ctx {
2390                latents = crate::img2img::apply_flow_match_inpaint(
2391                    &latents,
2392                    ctx,
2393                    scheduler.sigmas[step + 1],
2394                )?;
2395            }
2396
2397            if std::env::var_os("MOLD_QWEN_DEBUG").is_some() {
2398                let n = latents
2399                    .ne(&latents)?
2400                    .to_dtype(candle_core::DType::U32)?
2401                    .sum_all()?
2402                    .to_scalar::<u32>()?;
2403                if n > 0 {
2404                    eprintln!(
2405                        "[qwen-nan] NaN in latents AFTER step {step}: {n}/{}",
2406                        latents.elem_count()
2407                    );
2408                }
2409            }
2410            self.base.progress.emit(ProgressEvent::DenoiseStep {
2411                step: step + 1,
2412                total: num_steps,
2413                elapsed: step_start.elapsed(),
2414            });
2415        }
2416
2417        self.base
2418            .progress
2419            .stage_done(&denoise_label, denoise_start.elapsed());
2420
2421        // Drop transformer and embeddings
2422        drop(transformer);
2423        drop(encoder_hidden_states);
2424        drop(encoder_attention_mask);
2425        drop(uncond_hs);
2426        drop(uncond_mask);
2427        device.synchronize()?;
2428        self.base.progress.info("Freed Qwen-Image transformer");
2429
2430        // --- Phase 3: Load VAE and decode ---
2431        if let Some(status) = memory_status_string() {
2432            self.base.progress.info(&status);
2433        }
2434
2435        // Reserve-adjusted reading: VAE placement is a budget decision.
2436        let free_for_vae = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
2437        let vae_on_gpu = should_use_gpu(
2438            device.is_cuda(),
2439            device.is_metal(),
2440            free_for_vae,
2441            VAE_DECODE_VRAM_THRESHOLD,
2442        );
2443        let vae_ref =
2444            effective_device_ref(self.pending_placement.as_ref(), |adv| Some(adv.vae), false);
2445        let vae_device = crate::device::resolve_device(Some(vae_ref), || {
2446            Ok(if vae_on_gpu {
2447                device.clone()
2448            } else {
2449                Device::Cpu
2450            })
2451        })?;
2452        let vae_on_gpu = !vae_device.is_cpu();
2453        // Always decode in F32 — BF16 convolutions accumulate quantization noise across
2454        // the 4 upsampling blocks, producing visible grain. Matches diffusers' force_upcast.
2455        let vae_dtype = DType::F32;
2456        let vae_device_label = if vae_on_gpu { "GPU" } else { "CPU" };
2457
2458        let vae_label = format!("Loading Qwen-Image VAE ({}, F32)", vae_device_label);
2459        self.base.progress.stage_start(&vae_label);
2460        let vae_start = Instant::now();
2461        let vae = self.load_vae(&vae_device, vae_dtype)?;
2462        self.base
2463            .progress
2464            .stage_done(&vae_label, vae_start.elapsed());
2465
2466        self.base.progress.stage_start("VAE decode");
2467        let vae_decode_start = Instant::now();
2468        let free_for_decode = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
2469        let prefer_tiled = Self::should_proactively_tile_vae_decode(
2470            req.width,
2471            req.height,
2472            vae_device.is_cuda(),
2473            free_for_decode,
2474        );
2475
2476        let image = Self::decode_vae_with_fallback(
2477            &latents,
2478            &vae,
2479            &vae_device,
2480            &device,
2481            &self.base.progress,
2482            prefer_tiled,
2483            || self.load_vae(&Device::Cpu, DType::F32),
2484        )?;
2485        Self::validate_qwen_tensor_boundary("image_pre_postprocess", &image)?;
2486        Self::debug_tensor_stats("image_pre_postprocess", &image);
2487        let image = postprocess_image(&image)?;
2488        let post_stats = Self::validate_qwen_tensor_boundary("image_postprocess", &image)?;
2489        Self::debug_tensor_stats("image_postprocess", &image);
2490        let image = image.i(0)?;
2491        if Self::near_black_image_stats(post_stats) {
2492            self.base.progress.info(
2493                "Qwen diagnostic: decoded image is near-black after VAE postprocess; inspect MOLD_QWEN_DEBUG tensor stats to separate denoise math from VAE decode",
2494            );
2495            tracing::warn!(
2496                min = post_stats.min,
2497                max = post_stats.max,
2498                mean = post_stats.mean,
2499                "Qwen decoded image is near-black after VAE postprocess"
2500            );
2501        }
2502
2503        self.base
2504            .progress
2505            .stage_done("VAE decode", vae_decode_start.elapsed());
2506
2507        let output_metadata = build_output_metadata(req, seed, None);
2508        let image_bytes = encode_image(
2509            &image,
2510            req.resolved_output_format(),
2511            req.width,
2512            req.height,
2513            output_metadata.as_ref(),
2514        )?;
2515
2516        let generation_time_ms = start.elapsed().as_millis() as u64;
2517        tracing::info!(
2518            generation_time_ms,
2519            seed,
2520            "sequential Qwen-Image generation complete"
2521        );
2522
2523        Ok(GenerateResponse {
2524            images: vec![ImageData {
2525                data: image_bytes,
2526                format: req.resolved_output_format(),
2527                width: req.width,
2528                height: req.height,
2529                index: 0,
2530            }],
2531            generation_time_ms,
2532            model: req.model.clone(),
2533            seed_used: seed,
2534            video: None,
2535            gpu: None,
2536        })
2537    }
2538
2539    fn generate_edit_loaded(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
2540        let progress = &self.base.progress;
2541        let start = Instant::now();
2542
2543        let loaded_ref = self
2544            .base
2545            .loaded
2546            .as_ref()
2547            .ok_or_else(|| anyhow::anyhow!("model not loaded"))?;
2548        let needs_reload = loaded_ref.transformer.is_none();
2549        if needs_reload {
2550            let mut loaded_mut = self
2551                .base
2552                .loaded
2553                .take()
2554                .ok_or_else(|| anyhow::anyhow!("model not loaded"))?;
2555            progress.stage_start("Reloading Qwen-Image transformer");
2556            let reload_start = Instant::now();
2557            self.reload_transformer(&mut loaded_mut, req.width as usize, req.height as usize)?;
2558            progress.stage_done("Reloading Qwen-Image transformer", reload_start.elapsed());
2559            self.base.loaded = Some(loaded_mut);
2560        }
2561
2562        let is_edit_family = self.is_edit_family();
2563        let loaded = self
2564            .base
2565            .loaded
2566            .as_mut()
2567            .ok_or_else(|| anyhow::anyhow!("model not loaded"))?;
2568        let seed = req.seed.unwrap_or_else(rand_seed);
2569        let width = req.width as usize;
2570        let height = req.height as usize;
2571        let edit_images = req
2572            .edit_images
2573            .as_ref()
2574            .ok_or_else(|| anyhow::anyhow!("qwen-image-edit requires edit_images"))?;
2575        let use_cfg = req.guidance > 1.0;
2576        let negative_prompt = req
2577            .negative_prompt
2578            .as_deref()
2579            .unwrap_or(QWEN_EMPTY_NEGATIVE_PROMPT);
2580        let formatted_prompt = Self::qwen_image_edit_prompt(&req.prompt, edit_images.len());
2581        let formatted_negative = Self::qwen_image_edit_prompt(negative_prompt, edit_images.len());
2582
2583        tracing::info!(
2584            prompt = %req.prompt,
2585            seed,
2586            width,
2587            height,
2588            steps = req.steps,
2589            edit_images = edit_images.len(),
2590            "starting Qwen-Image edit generation"
2591        );
2592
2593        if loaded.text_encoder.model.is_none() {
2594            let label = if loaded.text_encoder.is_parked() {
2595                "Unparking Qwen2.5 encoder (CPU→GPU)"
2596            } else {
2597                "Reloading Qwen2.5 encoder"
2598            };
2599            progress.stage_start(label);
2600            let reload_start = Instant::now();
2601            if loaded.text_encoder.is_parked() {
2602                loaded.text_encoder.unpark_to_gpu(progress)?;
2603            } else {
2604                loaded.text_encoder.reload(progress)?;
2605            }
2606            progress.stage_done(label, reload_start.elapsed());
2607        }
2608
2609        progress.stage_start("Encoding prompt (Qwen2.5 edit)");
2610        let encode_start = Instant::now();
2611        let (encoder_hidden_states, encoder_attention_mask, _) =
2612            loaded.text_encoder.encode_formatted_multimodal(
2613                &formatted_prompt,
2614                edit_images,
2615                &loaded.device,
2616                loaded.dtype,
2617            )?;
2618        progress.stage_done("Encoding prompt (Qwen2.5 edit)", encode_start.elapsed());
2619        let (encoder_hidden_states, encoder_attention_mask, uncond_hs, uncond_mask) = if use_cfg {
2620            progress.stage_start("Encoding negative prompt (Qwen2.5 edit)");
2621            let neg_start = Instant::now();
2622            let (hs, mask, _) = loaded.text_encoder.encode_formatted_multimodal(
2623                &formatted_negative,
2624                edit_images,
2625                &loaded.device,
2626                loaded.dtype,
2627            )?;
2628            progress.stage_done(
2629                "Encoding negative prompt (Qwen2.5 edit)",
2630                neg_start.elapsed(),
2631            );
2632            let ((cond_hs, cond_mask), (neg_hs, neg_mask)) = align_cfg_conditioning(
2633                &encoder_hidden_states,
2634                &encoder_attention_mask,
2635                &hs,
2636                &mask,
2637            )?;
2638            (cond_hs, cond_mask, Some(neg_hs), Some(neg_mask))
2639        } else {
2640            (encoder_hidden_states, encoder_attention_mask, None, None)
2641        };
2642
2643        let drop_text_encoder = is_edit_family || loaded.text_encoder.on_gpu;
2644        if drop_text_encoder {
2645            let park_mode = crate::device::keep_te_in_ram()
2646                && !loaded.device.is_metal()
2647                && !loaded.text_encoder.is_quantized;
2648            if park_mode {
2649                loaded.text_encoder.park_to_cpu()?;
2650                tracing::info!(
2651                    on_gpu = loaded.text_encoder.on_gpu,
2652                    "Qwen2.5 text encoder parked to CPU host RAM after edit conditioning"
2653                );
2654            } else {
2655                loaded.text_encoder.drop_weights();
2656                tracing::info!(
2657                    on_gpu = loaded.text_encoder.on_gpu,
2658                    "Qwen2.5 text encoder dropped after edit conditioning"
2659                );
2660            }
2661        }
2662
2663        let mut packed_input_storage = Vec::with_capacity(edit_images.len());
2664        let mut img_shapes = vec![(1usize, height / 16, width / 16)];
2665        progress.stage_start("Encoding edit images (VAE)");
2666        let encode_start = Instant::now();
2667        for image_bytes in edit_images {
2668            let (vae_width, vae_height) =
2669                Self::qwen_image_edit_image_dims(image_bytes, QWEN_IMAGE_EDIT_VAE_AREA)?;
2670            let encoded = Self::encode_vae_with_fallback(
2671                image_bytes,
2672                vae_width,
2673                vae_height,
2674                &loaded.vae,
2675                &loaded.vae_device,
2676                &loaded.device,
2677                progress,
2678                || {
2679                    Ok(QwenImageVae::load(
2680                        &loaded.vae_path,
2681                        &Device::Cpu,
2682                        DType::F32,
2683                        progress,
2684                    )?)
2685                },
2686            )?
2687            .to_device(&loaded.device)?
2688            .to_dtype(loaded.dtype)?;
2689            img_shapes.push((1, encoded.dim(2)? / 2, encoded.dim(3)? / 2));
2690            packed_input_storage.push(Self::pack_latents_4d(&encoded)?);
2691        }
2692        progress.stage_done("Encoding edit images (VAE)", encode_start.elapsed());
2693
2694        let packed_inputs = if packed_input_storage.is_empty() {
2695            None
2696        } else {
2697            let tensors = packed_input_storage.iter().collect::<Vec<_>>();
2698            Some(Tensor::cat(&tensors, 1)?)
2699        };
2700
2701        let noise = crate::engine::seeded_randn(
2702            seed,
2703            &[1, 16, height / 8, width / 8],
2704            &loaded.device,
2705            loaded.dtype,
2706        )?;
2707        let mut scheduler =
2708            QwenImageScheduler::new(req.steps as usize, (height / 16) * (width / 16));
2709        let num_steps = scheduler.num_steps();
2710        let mut latents = Self::pack_latents_4d(&(noise * scheduler.initial_sigma())?)?;
2711        let output_seq_len = latents.dim(1)?;
2712
2713        let denoise_label = format!("Denoising edit ({} steps)", num_steps);
2714        progress.stage_start(&denoise_label);
2715        let denoise_start = Instant::now();
2716
2717        {
2718            let transformer = loaded
2719                .transformer
2720                .as_ref()
2721                .expect("transformer must be loaded for denoising");
2722            let use_batched_cfg = use_cfg && transformer.supports_cfg_batching();
2723            let (batched_hs, batched_mask) = if use_batched_cfg {
2724                let hs = Tensor::cat(&[&encoder_hidden_states, uncond_hs.as_ref().unwrap()], 0)?;
2725                let mask =
2726                    Tensor::cat(&[&encoder_attention_mask, uncond_mask.as_ref().unwrap()], 0)?;
2727                (hs, mask)
2728            } else {
2729                (
2730                    encoder_hidden_states.clone(),
2731                    encoder_attention_mask.clone(),
2732                )
2733            };
2734
2735            for step in 0..num_steps {
2736                let step_start = Instant::now();
2737                let t = scheduler.current_timestep();
2738                let timestep = if use_batched_cfg {
2739                    Tensor::from_vec(vec![t as f32; 2], (2,), &loaded.device)?
2740                        .to_dtype(loaded.dtype)?
2741                } else {
2742                    Tensor::from_vec(vec![t as f32], (1,), &loaded.device)?
2743                        .to_dtype(loaded.dtype)?
2744                };
2745
2746                let latent_model_input = if let Some(ref packed_inputs) = packed_inputs {
2747                    Tensor::cat(&[&latents, packed_inputs], 1)?
2748                } else {
2749                    latents.clone()
2750                };
2751
2752                let noise_pred = if use_cfg {
2753                    let (cond_pred, uncond_pred) = if use_batched_cfg {
2754                        let batched_input =
2755                            Tensor::cat(&[&latent_model_input, &latent_model_input], 0)?;
2756                        let pred = transformer.forward_packed(
2757                            &batched_input,
2758                            &timestep,
2759                            &batched_hs,
2760                            &batched_mask,
2761                            &img_shapes,
2762                        )?;
2763                        (
2764                            pred.narrow(0, 0, 1)?.narrow(1, 0, output_seq_len)?,
2765                            pred.narrow(0, 1, 1)?.narrow(1, 0, output_seq_len)?,
2766                        )
2767                    } else {
2768                        (
2769                            transformer
2770                                .forward_packed(
2771                                    &latent_model_input,
2772                                    &timestep,
2773                                    &encoder_hidden_states,
2774                                    &encoder_attention_mask,
2775                                    &img_shapes,
2776                                )?
2777                                .narrow(1, 0, output_seq_len)?,
2778                            transformer
2779                                .forward_packed(
2780                                    &latent_model_input,
2781                                    &timestep,
2782                                    uncond_hs.as_ref().unwrap(),
2783                                    uncond_mask.as_ref().unwrap(),
2784                                    &img_shapes,
2785                                )?
2786                                .narrow(1, 0, output_seq_len)?,
2787                        )
2788                    };
2789
2790                    let cond_f32 = cond_pred.to_dtype(DType::F32)?;
2791                    let uncond_f32 = uncond_pred.to_dtype(DType::F32)?;
2792                    let comb = (&uncond_f32 + ((&cond_f32 - &uncond_f32)? * req.guidance)?)?;
2793                    let cond_norm = cond_f32.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?;
2794                    let comb_norm = comb
2795                        .sqr()?
2796                        .sum_keepdim(D::Minus1)?
2797                        .sqrt()?
2798                        .clamp(1e-8, f64::MAX)?;
2799                    comb.broadcast_mul(&(cond_norm / comb_norm)?)?
2800                        .to_dtype(loaded.dtype)?
2801                } else {
2802                    transformer
2803                        .forward_packed(
2804                            &latent_model_input,
2805                            &timestep,
2806                            &encoder_hidden_states,
2807                            &encoder_attention_mask,
2808                            &img_shapes,
2809                        )?
2810                        .narrow(1, 0, output_seq_len)?
2811                };
2812
2813                latents = scheduler.step(&noise_pred, &latents)?;
2814                progress.emit(ProgressEvent::DenoiseStep {
2815                    step: step + 1,
2816                    total: num_steps,
2817                    elapsed: step_start.elapsed(),
2818                });
2819            }
2820        }
2821
2822        progress.stage_done(&denoise_label, denoise_start.elapsed());
2823
2824        let latents = Self::unpack_latents_packed(&latents, height / 8, width / 8)?;
2825        let free_for_decode = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
2826        let prefer_tiled = Self::should_proactively_tile_vae_decode(
2827            req.width,
2828            req.height,
2829            loaded.vae_device.is_cuda(),
2830            free_for_decode,
2831        );
2832        let image = Self::decode_vae_with_fallback(
2833            &latents,
2834            &loaded.vae,
2835            &loaded.vae_device,
2836            &loaded.device,
2837            progress,
2838            prefer_tiled,
2839            || {
2840                Ok(QwenImageVae::load(
2841                    &loaded.vae_path,
2842                    &Device::Cpu,
2843                    DType::F32,
2844                    progress,
2845                )?)
2846            },
2847        )?;
2848        let image = postprocess_image(&image)?.i(0)?;
2849        let output_metadata = build_output_metadata(req, seed, None);
2850        let image_bytes = encode_image(
2851            &image,
2852            req.resolved_output_format(),
2853            req.width,
2854            req.height,
2855            output_metadata.as_ref(),
2856        )?;
2857
2858        Ok(GenerateResponse {
2859            images: vec![ImageData {
2860                data: image_bytes,
2861                format: req.resolved_output_format(),
2862                width: req.width,
2863                height: req.height,
2864                index: 0,
2865            }],
2866            generation_time_ms: start.elapsed().as_millis() as u64,
2867            model: req.model.clone(),
2868            seed_used: seed,
2869            video: None,
2870            gpu: None,
2871        })
2872    }
2873}
2874
2875impl QwenImageEngine {
2876    fn generate_inner(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
2877        if req.scheduler.is_some() {
2878            tracing::warn!(
2879                "scheduler selection not supported for Qwen-Image (flow-matching), ignoring"
2880            );
2881        }
2882
2883        if self.is_edit_family() {
2884            let sequential = self.base.load_strategy == LoadStrategy::Sequential;
2885            if sequential && self.base.loaded.is_none() {
2886                let original = self.base.load_strategy;
2887                self.base.load_strategy = LoadStrategy::Eager;
2888                let load_result = self.load();
2889                self.base.load_strategy = original;
2890                load_result?;
2891            }
2892            if self.base.loaded.is_none() {
2893                bail!("model not loaded -- call load() first");
2894            }
2895            let result = self.generate_edit_loaded(req);
2896            if sequential {
2897                self.unload();
2898            }
2899            return result;
2900        }
2901
2902        // Sequential mode: load-use-drop each component
2903        if self.base.load_strategy == LoadStrategy::Sequential {
2904            return self.generate_sequential(req);
2905        }
2906
2907        // Eager mode: use pre-loaded components
2908        if self.base.loaded.is_none() {
2909            bail!("model not loaded -- call load() first");
2910        }
2911
2912        let progress = &self.base.progress;
2913        let gpu_ordinal = self.base.gpu_ordinal;
2914        let start = Instant::now();
2915
2916        // Reload transformer if it was dropped after previous VAE decode
2917        let loaded_ref = self
2918            .base
2919            .loaded
2920            .as_ref()
2921            .ok_or_else(|| anyhow::anyhow!("model not loaded"))?;
2922        let needs_reload = loaded_ref.transformer.is_none();
2923        if needs_reload {
2924            let mut loaded_mut = self
2925                .base
2926                .loaded
2927                .take()
2928                .ok_or_else(|| anyhow::anyhow!("model not loaded"))?;
2929            progress.stage_start("Reloading Qwen-Image transformer");
2930            let reload_start = Instant::now();
2931            self.reload_transformer(&mut loaded_mut, req.width as usize, req.height as usize)?;
2932            progress.stage_done("Reloading Qwen-Image transformer", reload_start.elapsed());
2933            self.base.loaded = Some(loaded_mut);
2934        }
2935
2936        let loaded = self
2937            .base
2938            .loaded
2939            .as_mut()
2940            .ok_or_else(|| anyhow::anyhow!("model not loaded"))?;
2941        let seed = req.seed.unwrap_or_else(rand_seed);
2942
2943        let width = req.width as usize;
2944        let height = req.height as usize;
2945
2946        tracing::info!(
2947            prompt = %req.prompt,
2948            seed, width, height,
2949            steps = req.steps,
2950            "starting Qwen-Image generation"
2951        );
2952
2953        let use_cfg = req.guidance > 1.0;
2954        let prompt_key = prompt_text_key(&req.prompt);
2955        let uncond_key = prompt_text_key(QWEN_EMPTY_NEGATIVE_PROMPT);
2956        let prompt_cached = self
2957            .prompt_cache
2958            .lock()
2959            .expect("cache poisoned")
2960            .get_cloned(&prompt_key);
2961        let uncond_cached = if use_cfg {
2962            self.prompt_cache
2963                .lock()
2964                .expect("cache poisoned")
2965                .get_cloned(&uncond_key)
2966        } else {
2967            None
2968        };
2969        let both_cached = prompt_cached.is_some() && (!use_cfg || uncond_cached.is_some());
2970
2971        let (encoder_hidden_states, encoder_attention_mask, uncond_hs, uncond_mask) = if both_cached
2972        {
2973            let cached = prompt_cached.expect("prompt cache unexpectedly missing");
2974            progress.cache_hit("prompt conditioning");
2975            let (hs, mask) = cached.restore(&loaded.device, loaded.dtype)?;
2976            let (u_hs, u_mask) = if use_cfg {
2977                progress.cache_hit("unconditional conditioning");
2978                let ucached =
2979                    uncond_cached.expect("unconditional prompt cache unexpectedly missing");
2980                let (u_hs, u_mask) = ucached.restore(&loaded.device, loaded.dtype)?;
2981                (Some(u_hs), Some(u_mask))
2982            } else {
2983                (None, None)
2984            };
2985            (hs, mask, u_hs, u_mask)
2986        } else {
2987            if loaded.text_encoder.model.is_none() {
2988                let label = if loaded.text_encoder.is_parked() {
2989                    "Unparking Qwen2.5 encoder (CPU→GPU)"
2990                } else {
2991                    "Reloading Qwen2.5 encoder"
2992                };
2993                progress.stage_start(label);
2994                let reload_start = Instant::now();
2995                if loaded.text_encoder.is_parked() {
2996                    loaded.text_encoder.unpark_to_gpu(progress)?;
2997                } else {
2998                    loaded.text_encoder.reload(progress)?;
2999                }
3000                progress.stage_done(label, reload_start.elapsed());
3001            }
3002
3003            let (hs, mask) = Self::encode_prompt_cached(
3004                progress,
3005                &self.prompt_cache,
3006                &mut loaded.text_encoder,
3007                &req.prompt,
3008                &loaded.device,
3009                loaded.dtype,
3010            )?;
3011
3012            let (u_hs, u_mask) = if use_cfg {
3013                let (hs, mask) = Self::encode_prompt_cached(
3014                    progress,
3015                    &self.prompt_cache,
3016                    &mut loaded.text_encoder,
3017                    QWEN_EMPTY_NEGATIVE_PROMPT,
3018                    &loaded.device,
3019                    loaded.dtype,
3020                )?;
3021                (Some(hs), Some(mask))
3022            } else {
3023                (None, None)
3024            };
3025
3026            (hs, mask, u_hs, u_mask)
3027        };
3028
3029        let (encoder_hidden_states, encoder_attention_mask, uncond_hs, uncond_mask) = if use_cfg {
3030            let ((cond_hs, cond_mask), (neg_hs, neg_mask)) = align_cfg_conditioning(
3031                &encoder_hidden_states,
3032                &encoder_attention_mask,
3033                uncond_hs.as_ref().expect("unconditional prompt missing"),
3034                uncond_mask.as_ref().expect("unconditional mask missing"),
3035            )?;
3036            (cond_hs, cond_mask, Some(neg_hs), Some(neg_mask))
3037        } else {
3038            (
3039                encoder_hidden_states,
3040                encoder_attention_mask,
3041                uncond_hs,
3042                uncond_mask,
3043            )
3044        };
3045
3046        // Drop or park text encoder to free VRAM for denoising.
3047        if loaded.text_encoder.on_gpu {
3048            let free_after_encode = usable_free_vram_bytes(gpu_ordinal).unwrap_or(0);
3049            let required_for_residency = Self::qwen2_hot_text_encoder_required_vram(
3050                req.width,
3051                req.height,
3052                if req.guidance > 1.0 { 2 } else { 1 },
3053                loaded.dtype,
3054            );
3055            let action =
3056                Self::qwen2_text_encoder_post_encode_action(Qwen2TextEncoderResidencyInput {
3057                    on_gpu: loaded.text_encoder.on_gpu,
3058                    is_quantized: loaded.text_encoder.is_quantized,
3059                    is_metal: loaded.device.is_metal(),
3060                    keep_te_ram: crate::device::keep_te_in_ram(),
3061                    prompt_cache_miss: !both_cached,
3062                    transformer_resident: loaded.transformer.is_some(),
3063                    free_vram_bytes: free_after_encode,
3064                    required_vram_bytes: required_for_residency,
3065                });
3066            match action {
3067                Qwen2TextEncoderPostEncodeAction::KeepGpu => {
3068                    progress.info(&format!(
3069                        "Keeping Qwen2.5 text encoder on GPU for hot prompt-cache misses ({} free >= {} reserve)",
3070                        fmt_gb(free_after_encode),
3071                        fmt_gb(required_for_residency)
3072                    ));
3073                    tracing::info!(
3074                        free_vram_bytes = free_after_encode,
3075                        required_vram_bytes = required_for_residency,
3076                        is_quantized = loaded.text_encoder.is_quantized,
3077                        "Qwen2.5 text encoder kept on GPU after cache miss"
3078                    );
3079                }
3080                Qwen2TextEncoderPostEncodeAction::ParkCpu => {
3081                    loaded.text_encoder.park_to_cpu()?;
3082                    progress.info(&format!(
3083                        "Parked Qwen2.5 text encoder to CPU host RAM before denoise ({} free < {} reserve)",
3084                        fmt_gb(free_after_encode),
3085                        fmt_gb(required_for_residency)
3086                    ));
3087                    tracing::info!("Qwen2.5 text encoder parked to CPU host RAM");
3088                }
3089                Qwen2TextEncoderPostEncodeAction::Drop => {
3090                    loaded.text_encoder.drop_weights();
3091                    progress.info(&format!(
3092                        "Dropped Qwen2.5 text encoder before denoise ({} free < {} reserve or cache hit)",
3093                        fmt_gb(free_after_encode),
3094                        fmt_gb(required_for_residency)
3095                    ));
3096                    tracing::info!("Qwen2.5 text encoder dropped from GPU");
3097                }
3098            }
3099        }
3100
3101        // 3. Calculate latent dimensions
3102        let vae_downsample = 8;
3103        let latent_h = height / vae_downsample;
3104        let latent_w = width / vae_downsample;
3105        let is_img2img = req.source_image.is_some();
3106
3107        // For img2img, encode source image using the pre-loaded VAE
3108        let (prepared_img2img_latents, inpaint_ctx) =
3109            if let Some(ref source_bytes) = req.source_image {
3110                let encoded = Self::encode_vae_with_fallback(
3111                    source_bytes,
3112                    req.width,
3113                    req.height,
3114                    &loaded.vae,
3115                    &loaded.vae_device,
3116                    &loaded.device,
3117                    progress,
3118                    || {
3119                        Ok(QwenImageVae::load(
3120                            &loaded.vae_path,
3121                            &Device::Cpu,
3122                            DType::F32,
3123                            progress,
3124                        )?)
3125                    },
3126                )?;
3127                let encoded = encoded.to_device(&loaded.device)?.to_dtype(loaded.dtype)?;
3128                let start_sigma = QwenImageScheduler::new_img2img(
3129                    req.steps as usize,
3130                    image_seq_len(latent_h, latent_w, loaded.transformer_cfg.patch_size),
3131                    req.strength,
3132                )
3133                .0
3134                .initial_sigma();
3135                let prepared = crate::img2img::prepare_flow_match_img2img(
3136                    &encoded,
3137                    seed,
3138                    &[1, 16, latent_h, latent_w],
3139                    start_sigma,
3140                    req.mask_image.as_deref(),
3141                    latent_h,
3142                    latent_w,
3143                    &loaded.device,
3144                    loaded.dtype,
3145                )?;
3146
3147                (Some(prepared.initial_latents), prepared.inpaint_ctx)
3148            } else {
3149                (None, None)
3150            };
3151
3152        // 4. Initialize scheduler
3153        let image_seq_len = image_seq_len(latent_h, latent_w, loaded.transformer_cfg.patch_size);
3154        let (mut scheduler, num_steps) = if is_img2img {
3155            QwenImageScheduler::new_img2img(req.steps as usize, image_seq_len, req.strength)
3156        } else {
3157            let sched = QwenImageScheduler::new(req.steps as usize, image_seq_len);
3158            let n = sched.num_steps();
3159            (sched, n)
3160        };
3161
3162        // 5. Build initial latents
3163        let mut latents = if let Some(initial) = &prepared_img2img_latents {
3164            initial.clone()
3165        } else {
3166            let noise = crate::engine::seeded_randn(
3167                seed,
3168                &[1, 16, latent_h, latent_w],
3169                &loaded.device,
3170                loaded.dtype,
3171            )?;
3172            (noise * scheduler.initial_sigma())?
3173        };
3174
3175        // 7. Denoising loop
3176        let denoise_label = format!("Denoising ({} steps)", num_steps);
3177        progress.stage_start(&denoise_label);
3178        let denoise_start = Instant::now();
3179
3180        {
3181            let transformer = loaded
3182                .transformer
3183                .as_ref()
3184                .expect("transformer must be loaded for denoising");
3185
3186            let use_batched_cfg = use_cfg && transformer.supports_cfg_batching();
3187            if use_cfg && !use_batched_cfg {
3188                progress.info(
3189                    "Low-memory quantized Qwen CUDA path detected — disabling CFG batching to reduce peak CUDA memory",
3190                );
3191            }
3192
3193            // Pre-batch CFG inputs when the selected transformer path can handle
3194            // the extra batch dimension without exceeding peak memory.
3195            let (batched_hs, batched_mask) = if use_batched_cfg {
3196                let hs = Tensor::cat(&[&encoder_hidden_states, uncond_hs.as_ref().unwrap()], 0)?;
3197                let mask =
3198                    Tensor::cat(&[&encoder_attention_mask, uncond_mask.as_ref().unwrap()], 0)?;
3199                (hs, mask)
3200            } else {
3201                (
3202                    encoder_hidden_states.clone(),
3203                    encoder_attention_mask.clone(),
3204                )
3205            };
3206
3207            for step in 0..num_steps {
3208                let step_start = Instant::now();
3209                let t = scheduler.current_timestep();
3210                let noise_pred = if use_cfg {
3211                    let (cond_pred, uncond_pred) = if use_batched_cfg {
3212                        let t_tensor = Tensor::from_vec(vec![t as f32; 2], (2,), &loaded.device)?
3213                            .to_dtype(loaded.dtype)?;
3214                        let batched_latents = Tensor::cat(&[&latents, &latents], 0)?;
3215                        let batched_pred = transformer.forward(
3216                            &batched_latents,
3217                            &t_tensor,
3218                            &batched_hs,
3219                            &batched_mask,
3220                        )?;
3221                        (batched_pred.narrow(0, 0, 1)?, batched_pred.narrow(0, 1, 1)?)
3222                    } else {
3223                        let t_tensor = Tensor::from_vec(vec![t as f32], (1,), &loaded.device)?
3224                            .to_dtype(loaded.dtype)?;
3225                        (
3226                            transformer.forward(
3227                                &latents,
3228                                &t_tensor,
3229                                &encoder_hidden_states,
3230                                &encoder_attention_mask,
3231                            )?,
3232                            transformer.forward(
3233                                &latents,
3234                                &t_tensor,
3235                                uncond_hs.as_ref().unwrap(),
3236                                uncond_mask.as_ref().unwrap(),
3237                            )?,
3238                        )
3239                    };
3240                    // CFG in F32 + norm rescale (matches diffusers Qwen-Image pipeline)
3241                    let cond_f32 = cond_pred.to_dtype(DType::F32)?;
3242                    let uncond_f32 = uncond_pred.to_dtype(DType::F32)?;
3243                    let comb = (&uncond_f32 + ((&cond_f32 - &uncond_f32)? * req.guidance)?)?;
3244                    let cond_norm = cond_f32.sqr()?.sum_keepdim(1)?.sqrt()?;
3245                    let comb_norm = comb.sqr()?.sum_keepdim(1)?.sqrt()?.clamp(1e-8, f64::MAX)?;
3246                    let rescaled = comb.broadcast_mul(&(cond_norm / comb_norm)?)?;
3247                    rescaled.to_dtype(loaded.dtype)?
3248                } else {
3249                    let t_tensor = Tensor::from_vec(vec![t as f32], (1,), &loaded.device)?
3250                        .to_dtype(loaded.dtype)?;
3251                    transformer.forward(
3252                        &latents,
3253                        &t_tensor,
3254                        &encoder_hidden_states,
3255                        &encoder_attention_mask,
3256                    )?
3257                };
3258                if step == 0 || step == num_steps / 2 || step == num_steps - 1 {
3259                    Self::debug_tensor_stats(&format!("noise_pred[{step}]"), &noise_pred);
3260                    Self::debug_tensor_stats(&format!("latents[{step}]"), &latents);
3261                }
3262                if step == 0 {
3263                    Self::validate_qwen_tensor_boundary("noise_pred[0]", &noise_pred)?;
3264                }
3265                latents = scheduler.step(&noise_pred, &latents)?;
3266                if step == num_steps - 1 {
3267                    Self::validate_qwen_tensor_boundary("latents_final", &latents)?;
3268                }
3269
3270                // Inpainting: blend preserved regions back at current noise level
3271                if let Some(ref ctx) = inpaint_ctx {
3272                    latents = crate::img2img::apply_flow_match_inpaint(
3273                        &latents,
3274                        ctx,
3275                        scheduler.sigmas[step + 1],
3276                    )?;
3277                }
3278
3279                progress.emit(ProgressEvent::DenoiseStep {
3280                    step: step + 1,
3281                    total: num_steps,
3282                    elapsed: step_start.elapsed(),
3283                });
3284            }
3285        }
3286
3287        progress.stage_done(&denoise_label, denoise_start.elapsed());
3288
3289        // Free text embeddings
3290        drop(encoder_hidden_states);
3291        drop(encoder_attention_mask);
3292        drop(uncond_hs);
3293        drop(uncond_mask);
3294
3295        // 8. VAE decode
3296        progress.stage_start("VAE decode");
3297        let vae_start = Instant::now();
3298        let free_for_decode = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
3299        let prefer_tiled = Self::should_proactively_tile_vae_decode(
3300            req.width,
3301            req.height,
3302            loaded.vae_device.is_cuda(),
3303            free_for_decode,
3304        );
3305
3306        // Always decode in F32 — matches sequential path and diffusers' force_upcast.
3307        let keep_transformer_hot = Self::can_keep_transformer_hot_for_vae(loaded);
3308        let image = if keep_transformer_hot {
3309            match Self::decode_vae_gpu_only(
3310                &latents,
3311                &loaded.vae,
3312                &loaded.vae_device,
3313                &loaded.device,
3314                progress,
3315                prefer_tiled,
3316            ) {
3317                Ok(image) => {
3318                    progress.info(
3319                        "Kept quantized Qwen transformer resident across VAE decode for faster hot-path reuse",
3320                    );
3321                    image
3322                }
3323                Err(err) if Self::is_oom_error(&err) => {
3324                    loaded.transformer = None;
3325                    loaded.device.synchronize()?;
3326                    progress.info(
3327                        "Dropping Qwen-Image transformer after resident VAE decode OOM and retrying",
3328                    );
3329                    Self::decode_vae_with_fallback(
3330                        &latents,
3331                        &loaded.vae,
3332                        &loaded.vae_device,
3333                        &loaded.device,
3334                        progress,
3335                        prefer_tiled,
3336                        || {
3337                            QwenImageVae::load(&loaded.vae_path, &Device::Cpu, DType::F32, progress)
3338                                .map_err(Into::into)
3339                        },
3340                    )?
3341                }
3342                Err(err) => return Err(err),
3343            }
3344        } else {
3345            loaded.transformer = None;
3346            loaded.device.synchronize()?;
3347            tracing::info!("Qwen-Image transformer dropped to free VRAM for VAE decode");
3348            Self::decode_vae_with_fallback(
3349                &latents,
3350                &loaded.vae,
3351                &loaded.vae_device,
3352                &loaded.device,
3353                progress,
3354                prefer_tiled,
3355                || {
3356                    QwenImageVae::load(&loaded.vae_path, &Device::Cpu, DType::F32, progress)
3357                        .map_err(Into::into)
3358                },
3359            )?
3360        };
3361        Self::validate_qwen_tensor_boundary("image_pre_postprocess", &image)?;
3362        Self::debug_tensor_stats("image_pre_postprocess", &image);
3363        let image = postprocess_image(&image)?;
3364        let post_stats = Self::validate_qwen_tensor_boundary("image_postprocess", &image)?;
3365        Self::debug_tensor_stats("image_postprocess", &image);
3366        let image = image.i(0)?;
3367        if Self::near_black_image_stats(post_stats) {
3368            progress.info(
3369                "Qwen diagnostic: decoded image is near-black after VAE postprocess; inspect MOLD_QWEN_DEBUG tensor stats to separate denoise math from VAE decode",
3370            );
3371            tracing::warn!(
3372                min = post_stats.min,
3373                max = post_stats.max,
3374                mean = post_stats.mean,
3375                "Qwen decoded image is near-black after VAE postprocess"
3376            );
3377        }
3378
3379        progress.stage_done("VAE decode", vae_start.elapsed());
3380
3381        // 9. Encode to output format
3382        let output_metadata = build_output_metadata(req, seed, None);
3383        let image_bytes = encode_image(
3384            &image,
3385            req.resolved_output_format(),
3386            req.width,
3387            req.height,
3388            output_metadata.as_ref(),
3389        )?;
3390
3391        let generation_time_ms = start.elapsed().as_millis() as u64;
3392        tracing::info!(generation_time_ms, seed, "Qwen-Image generation complete");
3393
3394        Ok(GenerateResponse {
3395            images: vec![ImageData {
3396                data: image_bytes,
3397                format: req.resolved_output_format(),
3398                width: req.width,
3399                height: req.height,
3400                index: 0,
3401            }],
3402            generation_time_ms,
3403            model: req.model.clone(),
3404            seed_used: seed,
3405            video: None,
3406            gpu: None,
3407        })
3408    }
3409}
3410
3411impl InferenceEngine for QwenImageEngine {
3412    fn generate(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
3413        self.pending_placement = req.placement.clone();
3414        self.pending_loras = effective_loras(req);
3415        let result = self.generate_inner(req);
3416        self.pending_placement = None;
3417        self.pending_loras.clear();
3418        result
3419    }
3420
3421    fn model_name(&self) -> &str {
3422        self.base.model_name()
3423    }
3424
3425    fn is_loaded(&self) -> bool {
3426        self.base.is_loaded()
3427    }
3428
3429    fn load(&mut self) -> Result<()> {
3430        QwenImageEngine::load(self)
3431    }
3432
3433    fn unload(&mut self) {
3434        self.base.unload();
3435        clear_cache(&self.prompt_cache);
3436    }
3437
3438    fn set_on_progress(&mut self, callback: ProgressCallback) {
3439        self.base.set_on_progress(callback);
3440    }
3441
3442    fn clear_on_progress(&mut self) {
3443        self.base.clear_on_progress();
3444    }
3445
3446    fn model_paths(&self) -> Option<&mold_core::ModelPaths> {
3447        Some(&self.base.paths)
3448    }
3449}
3450
3451#[cfg(test)]
3452mod tests {
3453    use super::*;
3454    use crate::engine::LoadStrategy;
3455    use crate::shared_pool::SharedPool;
3456    use candle_core::Shape;
3457    use mold_core::ModelPaths;
3458    use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
3459    use std::collections::HashMap;
3460    use std::fs;
3461    use std::path::{Path, PathBuf};
3462    use std::sync::{Arc, Mutex};
3463    use std::time::{SystemTime, UNIX_EPOCH};
3464    use tokenizers::models::bpe::BPE;
3465
3466    fn temp_test_dir(prefix: &str) -> PathBuf {
3467        let suffix = SystemTime::now()
3468            .duration_since(UNIX_EPOCH)
3469            .unwrap()
3470            .as_nanos();
3471        let dir = std::env::temp_dir().join(format!("{prefix}-{}-{suffix}", std::process::id()));
3472        fs::create_dir_all(&dir).unwrap();
3473        dir
3474    }
3475
3476    fn touch(dir: &Path, name: &str) -> PathBuf {
3477        let path = dir.join(name);
3478        fs::write(&path, b"test").unwrap();
3479        path
3480    }
3481
3482    fn png_with_dimensions(width: u32, height: u32) -> Vec<u8> {
3483        let img = image::RgbImage::from_fn(width, height, |_, _| image::Rgb([255, 0, 0]));
3484        let mut buf = std::io::Cursor::new(Vec::new());
3485        image::DynamicImage::ImageRgb8(img)
3486            .write_to(&mut buf, image::ImageFormat::Png)
3487            .unwrap();
3488        buf.into_inner()
3489    }
3490
3491    fn qwen_image_model_paths(
3492        transformer: PathBuf,
3493        transformer_shards: Vec<PathBuf>,
3494        vae: PathBuf,
3495        text_tokenizer: Option<PathBuf>,
3496    ) -> ModelPaths {
3497        ModelPaths {
3498            transformer,
3499            transformer_shards,
3500            vae,
3501            spatial_upscaler: None,
3502            temporal_upscaler: None,
3503            distilled_lora: None,
3504            t5_encoder: None,
3505            clip_encoder: None,
3506            t5_tokenizer: None,
3507            clip_tokenizer: None,
3508            clip_encoder_2: None,
3509            clip_tokenizer_2: None,
3510            text_encoder_files: vec![],
3511            text_tokenizer,
3512            decoder: None,
3513        }
3514    }
3515
3516    fn resolved_text_encoder(is_gguf: bool, auto_use_gpu: bool) -> ResolvedQwen2TextEncoder {
3517        ResolvedQwen2TextEncoder {
3518            paths: vec![],
3519            vision_paths: vec![],
3520            is_gguf,
3521            variant_label: if is_gguf {
3522                "q6".to_string()
3523            } else {
3524                "bf16".to_string()
3525            },
3526            size_bytes: 0,
3527            auto_use_gpu,
3528        }
3529    }
3530
3531    fn tensor_values_u8(t: &Tensor) -> Vec<u8> {
3532        t.flatten_all()
3533            .unwrap()
3534            .to_vec1::<u8>()
3535            .expect("u8 tensor values")
3536    }
3537
3538    fn tensor_values_f32(t: &Tensor) -> Vec<f32> {
3539        t.flatten_all()
3540            .unwrap()
3541            .to_vec1::<f32>()
3542            .expect("f32 tensor values")
3543    }
3544
3545    #[test]
3546    fn safetensors_is_fp8_uses_filename_hint() {
3547        assert!(safetensors_is_fp8(Path::new(
3548            "/tmp/qwen-image-fp8.safetensors"
3549        )));
3550        assert!(!safetensors_is_fp8(Path::new(
3551            "/tmp/qwen-image.safetensors"
3552        )));
3553    }
3554
3555    #[test]
3556    fn text_encoder_is_fp8_uses_filename_hint() {
3557        assert!(text_encoder_is_fp8(&[PathBuf::from(
3558            "/tmp/qwen2-text-encoder-fp8-00001-of-00002.safetensors"
3559        )]));
3560        assert!(!text_encoder_is_fp8(&[PathBuf::from(
3561            "/tmp/qwen2-text-encoder-00001-of-00002.safetensors"
3562        )]));
3563    }
3564
3565    #[test]
3566    fn cached_prompt_conditioning_roundtrips_and_restores_mask() {
3567        let device = Device::Cpu;
3568        let hidden_states = Tensor::from_vec(
3569            vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
3570            Shape::from((1, 3, 2)),
3571            &device,
3572        )
3573        .unwrap();
3574        let cached = CachedPromptConditioning::from_parts(&hidden_states, 2).unwrap();
3575
3576        let (restored_hs, restored_mask) = cached.restore(&device, DType::F32).unwrap();
3577
3578        assert_eq!(
3579            tensor_values_f32(&restored_hs),
3580            tensor_values_f32(&hidden_states)
3581        );
3582        assert_eq!(tensor_values_u8(&restored_mask), vec![1, 1, 0]);
3583    }
3584
3585    #[test]
3586    fn pad_text_conditioning_keeps_original_when_target_matches() {
3587        let device = Device::Cpu;
3588        let hidden_states =
3589            Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], Shape::from((1, 2, 2)), &device).unwrap();
3590        let mask = Tensor::from_vec(vec![1u8, 1], Shape::from((1, 2)), &device).unwrap();
3591
3592        let (padded_hs, padded_mask) = pad_text_conditioning(&hidden_states, &mask, 2).unwrap();
3593
3594        assert_eq!(
3595            tensor_values_f32(&padded_hs),
3596            tensor_values_f32(&hidden_states)
3597        );
3598        assert_eq!(tensor_values_u8(&padded_mask), vec![1, 1]);
3599    }
3600
3601    #[test]
3602    fn pad_text_conditioning_appends_zero_padding() {
3603        let device = Device::Cpu;
3604        let hidden_states =
3605            Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], Shape::from((1, 2, 2)), &device).unwrap();
3606        let mask = Tensor::from_vec(vec![1u8, 0], Shape::from((1, 2)), &device).unwrap();
3607
3608        let (padded_hs, padded_mask) = pad_text_conditioning(&hidden_states, &mask, 4).unwrap();
3609
3610        assert_eq!(padded_hs.dims3().unwrap(), (1, 4, 2));
3611        assert_eq!(
3612            tensor_values_f32(&padded_hs),
3613            vec![1.0, 2.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0]
3614        );
3615        assert_eq!(tensor_values_u8(&padded_mask), vec![1, 0, 0, 0]);
3616    }
3617
3618    #[test]
3619    fn pad_text_conditioning_rejects_shrinking() {
3620        let device = Device::Cpu;
3621        let hidden_states =
3622            Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], Shape::from((1, 2, 2)), &device).unwrap();
3623        let mask = Tensor::from_vec(vec![1u8, 1], Shape::from((1, 2)), &device).unwrap();
3624
3625        let err = pad_text_conditioning(&hidden_states, &mask, 1).unwrap_err();
3626        assert!(err.to_string().contains("cannot shrink text conditioning"));
3627    }
3628
3629    #[test]
3630    fn align_cfg_conditioning_pads_shorter_branch_to_match_longer_one() {
3631        let device = Device::Cpu;
3632        let cond_hs = Tensor::from_vec(
3633            vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
3634            Shape::from((1, 3, 2)),
3635            &device,
3636        )
3637        .unwrap();
3638        let cond_mask = Tensor::from_vec(vec![1u8, 1, 1], Shape::from((1, 3)), &device).unwrap();
3639        let uncond_hs = Tensor::from_vec(
3640            vec![7.0f32, 8.0, 9.0, 10.0],
3641            Shape::from((1, 2, 2)),
3642            &device,
3643        )
3644        .unwrap();
3645        let uncond_mask = Tensor::from_vec(vec![1u8, 0], Shape::from((1, 2)), &device).unwrap();
3646
3647        let ((cond_hs, cond_mask), (uncond_hs, uncond_mask)) =
3648            align_cfg_conditioning(&cond_hs, &cond_mask, &uncond_hs, &uncond_mask).unwrap();
3649
3650        assert_eq!(cond_hs.dims3().unwrap(), (1, 3, 2));
3651        assert_eq!(uncond_hs.dims3().unwrap(), (1, 3, 2));
3652        assert_eq!(tensor_values_u8(&cond_mask), vec![1, 1, 1]);
3653        assert_eq!(tensor_values_u8(&uncond_mask), vec![1, 0, 0]);
3654        assert_eq!(
3655            tensor_values_f32(&uncond_hs),
3656            vec![7.0, 8.0, 9.0, 10.0, 0.0, 0.0]
3657        );
3658    }
3659
3660    #[test]
3661    fn qwen_image_detects_gguf_transformer() {
3662        let engine = QwenImageEngine::new(
3663            "qwen-image:q4".to_string(),
3664            ModelPaths {
3665                transformer: PathBuf::from("/tmp/qwen-image-Q4_K_S.gguf"),
3666                transformer_shards: vec![],
3667                vae: PathBuf::from("/tmp/vae.safetensors"),
3668                spatial_upscaler: None,
3669                temporal_upscaler: None,
3670                distilled_lora: None,
3671                t5_encoder: None,
3672                clip_encoder: None,
3673                t5_tokenizer: None,
3674                clip_tokenizer: None,
3675                clip_encoder_2: None,
3676                clip_tokenizer_2: None,
3677                text_encoder_files: vec![],
3678                text_tokenizer: Some(PathBuf::from("/tmp/tokenizer.json")),
3679                decoder: None,
3680            },
3681            LoadStrategy::Sequential,
3682            0,
3683            false,
3684            None,
3685        );
3686
3687        assert!(engine.detect_is_quantized());
3688    }
3689
3690    #[test]
3691    fn qwen_image_text_encoder_uses_gpu_on_metal() {
3692        let plan = QwenImageEngine::qwen2_text_encoder_plan_for_mode(
3693            Qwen2TextEncoderMode::Auto,
3694            false,
3695            true,
3696            &resolved_text_encoder(true, true),
3697        );
3698        assert!(plan.use_gpu);
3699        assert!(!plan.use_cpu_staging);
3700    }
3701
3702    #[test]
3703    fn qwen_image_text_encoder_uses_gpu_on_cuda_with_headroom() {
3704        let plan = QwenImageEngine::qwen2_text_encoder_plan_for_mode(
3705            Qwen2TextEncoderMode::Auto,
3706            true,
3707            false,
3708            &resolved_text_encoder(false, true),
3709        );
3710        assert!(plan.use_gpu);
3711        assert!(!plan.use_cpu_staging);
3712    }
3713
3714    #[test]
3715    fn qwen_image_text_encoder_uses_cpu_on_cuda_without_headroom() {
3716        let plan = QwenImageEngine::qwen2_text_encoder_plan_for_mode(
3717            Qwen2TextEncoderMode::Auto,
3718            true,
3719            false,
3720            &resolved_text_encoder(false, false),
3721        );
3722        assert!(!plan.use_gpu);
3723        assert!(!plan.use_cpu_staging);
3724    }
3725
3726    #[test]
3727    fn qwen_image_cpu_safetensors_text_encoder_stays_f32() {
3728        assert_eq!(
3729            QwenImageEngine::text_encoder_load_dtype(false, DType::BF16),
3730            DType::F32
3731        );
3732    }
3733
3734    #[test]
3735    fn qwen_image_cpu_gguf_text_encoder_stays_f32() {
3736        assert_eq!(
3737            QwenImageEngine::text_encoder_load_dtype(false, DType::BF16),
3738            DType::F32
3739        );
3740    }
3741
3742    #[test]
3743    fn qwen_image_text_encoder_gpu_override_disables_metal_staging() {
3744        let plan = QwenImageEngine::qwen2_text_encoder_plan_for_mode(
3745            Qwen2TextEncoderMode::Gpu,
3746            false,
3747            true,
3748            &resolved_text_encoder(true, true),
3749        );
3750        assert!(plan.use_gpu);
3751        assert!(!plan.use_cpu_staging);
3752    }
3753
3754    #[test]
3755    fn qwen_image_auto_prefers_q6_on_metal_with_headroom() {
3756        let q6 = mold_core::manifest::find_qwen2_vl_variant("q6").unwrap();
3757        let resolved = QwenImageEngine::choose_text_encoder_source(
3758            Some("auto"),
3759            false,
3760            true,
3761            qwen2_vram_threshold(q6.size_bytes) + 1,
3762            16_600_000_000,
3763            Qwen2TextEncoderUsage::Resident,
3764        )
3765        .unwrap();
3766        assert!(resolved.is_gguf);
3767        assert_eq!(resolved.variant_label, "q6");
3768        assert!(resolved.auto_use_gpu);
3769    }
3770
3771    #[test]
3772    fn qwen_image_auto_falls_back_to_q4_on_metal_when_q6_does_not_fit() {
3773        let q4 = mold_core::manifest::find_qwen2_vl_variant("q4").unwrap();
3774        let q6 = mold_core::manifest::find_qwen2_vl_variant("q6").unwrap();
3775        let free_vram = qwen2_vram_threshold(q4.size_bytes);
3776        assert!(free_vram < qwen2_vram_threshold(q6.size_bytes));
3777
3778        let resolved = QwenImageEngine::choose_text_encoder_source(
3779            Some("auto"),
3780            false,
3781            true,
3782            free_vram,
3783            0,
3784            Qwen2TextEncoderUsage::Resident,
3785        )
3786        .unwrap();
3787        assert!(resolved.is_gguf);
3788        assert_eq!(resolved.variant_label, "q4");
3789        assert!(resolved.auto_use_gpu);
3790    }
3791
3792    #[test]
3793    fn qwen_image_auto_keeps_bf16_default_on_cuda() {
3794        let resolved = QwenImageEngine::choose_text_encoder_source(
3795            Some("auto"),
3796            true,
3797            false,
3798            QWEN2_FP16_VRAM_THRESHOLD + 1,
3799            16_600_000_000,
3800            Qwen2TextEncoderUsage::Resident,
3801        )
3802        .unwrap();
3803        assert!(!resolved.is_gguf);
3804        assert_eq!(resolved.variant_label, "bf16");
3805        assert!(resolved.auto_use_gpu);
3806    }
3807
3808    #[test]
3809    fn qwen_image_auto_prefers_quantized_gpu_on_cuda_for_resident_mode_when_it_fits() {
3810        let resolved = QwenImageEngine::choose_text_encoder_source(
3811            Some("auto"),
3812            true,
3813            false,
3814            QWEN2_FP16_VRAM_THRESHOLD - 1,
3815            16_600_000_000,
3816            Qwen2TextEncoderUsage::Resident,
3817        )
3818        .unwrap();
3819        assert!(resolved.is_gguf);
3820        assert_eq!(resolved.variant_label, "q4");
3821        assert!(resolved.auto_use_gpu);
3822    }
3823
3824    #[test]
3825    fn qwen_image_auto_uses_quantized_cpu_fallback_on_cuda_for_resident_mode() {
3826        let resolved = QwenImageEngine::choose_text_encoder_source(
3827            Some("auto"),
3828            true,
3829            false,
3830            1,
3831            16_600_000_000,
3832            Qwen2TextEncoderUsage::Resident,
3833        )
3834        .unwrap();
3835        assert!(resolved.is_gguf);
3836        assert_eq!(resolved.variant_label, "q4");
3837        assert!(!resolved.auto_use_gpu);
3838    }
3839
3840    #[test]
3841    fn qwen_image_auto_prefers_quantized_gpu_on_cuda_for_sequential_mode_when_it_fits() {
3842        let resolved = QwenImageEngine::choose_text_encoder_source(
3843            Some("auto"),
3844            true,
3845            false,
3846            QWEN2_FP16_VRAM_THRESHOLD - 1,
3847            16_600_000_000,
3848            Qwen2TextEncoderUsage::Sequential,
3849        )
3850        .unwrap();
3851        assert!(resolved.is_gguf);
3852        assert_eq!(resolved.variant_label, "q4");
3853        assert!(resolved.auto_use_gpu);
3854    }
3855
3856    #[test]
3857    fn qwen_image_auto_uses_quantized_cpu_fallback_on_cuda_for_sequential_mode() {
3858        let resolved = QwenImageEngine::choose_text_encoder_source(
3859            Some("auto"),
3860            true,
3861            false,
3862            1,
3863            16_600_000_000,
3864            Qwen2TextEncoderUsage::Sequential,
3865        )
3866        .unwrap();
3867        assert!(resolved.is_gguf);
3868        assert_eq!(resolved.variant_label, "q4");
3869        assert!(!resolved.auto_use_gpu);
3870    }
3871
3872    #[test]
3873    fn qwen_image_explicit_q6_respects_cpu_fallback_on_cuda() {
3874        let resolved = QwenImageEngine::choose_text_encoder_source(
3875            Some("q6"),
3876            true,
3877            false,
3878            1,
3879            0,
3880            Qwen2TextEncoderUsage::Resident,
3881        )
3882        .unwrap();
3883        assert!(resolved.is_gguf);
3884        assert_eq!(resolved.variant_label, "q6");
3885        assert!(!resolved.auto_use_gpu);
3886    }
3887
3888    #[test]
3889    fn qwen_image_edit_accepts_quantized_text_with_bf16_vision_sidecar() {
3890        let dir = temp_test_dir("qwen-image-edit-text-encoder");
3891        let transformer = touch(&dir, "qwen-image-edit.gguf");
3892        let vae = touch(&dir, "vae.safetensors");
3893        let tokenizer = touch(&dir, "tokenizer.json");
3894        let mut paths = qwen_image_model_paths(transformer, vec![], vae, Some(tokenizer));
3895        paths.text_encoder_files = vec![touch(&dir, "text-encoder-00001-of-00004.safetensors")];
3896        let engine = QwenImageEngine::new(
3897            "qwen-image-edit-2511:q4".to_string(),
3898            paths,
3899            LoadStrategy::Sequential,
3900            0,
3901            false,
3902            None,
3903        );
3904
3905        let resolved = engine
3906            .resolve_text_encoder_source_with_preference(
3907                &Device::Cpu,
3908                0,
3909                Qwen2TextEncoderUsage::Sequential,
3910                Some("auto"),
3911            )
3912            .unwrap();
3913        assert!(!resolved.vision_paths.is_empty());
3914
3915        let resolved = engine
3916            .resolve_text_encoder_source_with_preference(
3917                &Device::Cpu,
3918                0,
3919                Qwen2TextEncoderUsage::Sequential,
3920                Some("q4"),
3921            )
3922            .unwrap();
3923        assert!(resolved.is_gguf);
3924        assert_eq!(resolved.variant_label, "q4");
3925        assert_eq!(resolved.vision_paths.len(), 1);
3926
3927        let resolved = engine
3928            .resolve_text_encoder_source_with_preference(
3929                &Device::Cpu,
3930                0,
3931                Qwen2TextEncoderUsage::Sequential,
3932                Some("bf16"),
3933            )
3934            .unwrap();
3935        assert!(!resolved.is_gguf);
3936        assert_eq!(resolved.variant_label, "bf16");
3937        assert_eq!(resolved.vision_paths.len(), 1);
3938    }
3939
3940    #[test]
3941    fn qwen_image_edit_prompt_numbers_each_picture_placeholder() {
3942        let prompt = QwenImageEngine::qwen_image_edit_prompt("swap materials", 3);
3943        assert!(prompt.contains(QWEN_IMAGE_EDIT_SYSTEM_PROMPT));
3944        assert!(prompt.contains("Picture 1: <|vision_start|><|image_pad|><|vision_end|>"));
3945        assert!(prompt.contains("Picture 2: <|vision_start|><|image_pad|><|vision_end|>"));
3946        assert!(prompt.contains("Picture 3: <|vision_start|><|image_pad|><|vision_end|>"));
3947        assert!(prompt.ends_with("<|im_start|>assistant\n"));
3948    }
3949
3950    #[test]
3951    fn qwen_image_edit_image_dims_fit_target_area_with_16px_alignment() {
3952        let bytes = png_with_dimensions(1600, 900);
3953        let (width, height) =
3954            QwenImageEngine::qwen_image_edit_image_dims(&bytes, QWEN_IMAGE_EDIT_VAE_AREA).unwrap();
3955        assert_eq!((width, height), (1360, 768));
3956        assert_eq!(width % 16, 0);
3957        assert_eq!(height % 16, 0);
3958    }
3959
3960    #[test]
3961    fn pack_and_unpack_latents_roundtrip() {
3962        let values: Vec<f32> = (0..(16 * 4 * 6)).map(|i| i as f32).collect();
3963        let latents = Tensor::from_vec(values.clone(), (1, 16, 4, 6), &Device::Cpu).unwrap();
3964        let packed = QwenImageEngine::pack_latents_4d(&latents).unwrap();
3965        assert_eq!(packed.dims3().unwrap(), (1, 6, 64));
3966
3967        let unpacked = QwenImageEngine::unpack_latents_packed(&packed, 4, 6).unwrap();
3968        assert_eq!(
3969            unpacked.flatten_all().unwrap().to_vec1::<f32>().unwrap(),
3970            values
3971        );
3972    }
3973
3974    #[test]
3975    fn quantized_cuda_cfg_headroom_scales_with_resolution() {
3976        let native = QwenImageEngine::quantized_cuda_cfg_headroom(1328, 1328);
3977        let reduced = QwenImageEngine::quantized_cuda_cfg_headroom(512, 512);
3978        assert_eq!(native, QWEN_GGUF_NATIVE_CFG_HEADROOM);
3979        assert_eq!(reduced, QWEN_GGUF_MIN_CFG_HEADROOM);
3980    }
3981
3982    #[test]
3983    fn qwen_quantized_native_resolution_uses_split_cfg_on_24gb_cuda() {
3984        assert!(QwenImageEngine::should_split_cfg_quantized_cuda(
3985            12_300_000_000,
3986            24_600_000_000,
3987            1328,
3988            1328,
3989        ));
3990    }
3991
3992    #[test]
3993    fn qwen_quantized_reduced_resolution_keeps_batched_cfg_when_it_fits() {
3994        assert!(!QwenImageEngine::should_split_cfg_quantized_cuda(
3995            12_300_000_000,
3996            24_600_000_000,
3997            512,
3998            512,
3999        ));
4000    }
4001
4002    #[test]
4003    fn qwen_quantized_cfg_split_boundary_does_not_split_when_estimate_exactly_fits() {
4004        let headroom = QwenImageEngine::quantized_cuda_cfg_headroom(1328, 1328);
4005        let transformer_size = 12_300_000_000;
4006        let free_vram = transformer_size + headroom;
4007        assert!(!QwenImageEngine::should_split_cfg_quantized_cuda(
4008            transformer_size,
4009            free_vram,
4010            1328,
4011            1328,
4012        ));
4013    }
4014
4015    #[test]
4016    fn qwen_quantized_unknown_vram_biases_to_split_cfg() {
4017        assert!(QwenImageEngine::should_split_cfg_quantized_cuda(
4018            12_300_000_000,
4019            0,
4020            1328,
4021            1328,
4022        ));
4023    }
4024
4025    #[test]
4026    fn qwen_is_oom_error_matches_cuda_memory_allocation_string() {
4027        assert!(QwenImageEngine::is_oom_error(&"cudaErrorMemoryAllocation"));
4028    }
4029
4030    #[test]
4031    fn qwen_debug_stats_counts_nan_and_inf() {
4032        let tensor = Tensor::from_vec(
4033            vec![0.0f32, 1.0, f32::NAN, f32::INFINITY, f32::NEG_INFINITY],
4034            Shape::from((5,)),
4035            &Device::Cpu,
4036        )
4037        .unwrap();
4038
4039        let stats = QwenImageEngine::tensor_stats(&tensor).unwrap();
4040
4041        assert_eq!(stats.total, 5);
4042        assert_eq!(stats.nan_count, 1);
4043        assert_eq!(stats.pos_inf_count, 1);
4044        assert_eq!(stats.neg_inf_count, 1);
4045        assert_eq!(stats.min, 0.0);
4046        assert_eq!(stats.max, 1.0);
4047        assert_eq!(stats.mean, 0.5);
4048    }
4049
4050    #[test]
4051    fn qwen_debug_stats_detects_near_black_postprocessed_image() {
4052        let stats = QwenTensorStats {
4053            min: 0.0,
4054            max: 0.01,
4055            mean: 0.004,
4056            nan_count: 0,
4057            pos_inf_count: 0,
4058            neg_inf_count: 0,
4059            total: 1024,
4060        };
4061
4062        assert!(QwenImageEngine::near_black_image_stats(stats));
4063    }
4064
4065    #[test]
4066    fn qwen_debug_stats_does_not_flag_non_black_image() {
4067        let stats = QwenTensorStats {
4068            min: 0.0,
4069            max: 0.75,
4070            mean: 0.18,
4071            nan_count: 0,
4072            pos_inf_count: 0,
4073            neg_inf_count: 0,
4074            total: 1024,
4075        };
4076
4077        assert!(!QwenImageEngine::near_black_image_stats(stats));
4078    }
4079
4080    #[test]
4081    fn qwen_debug_stats_formats_progress_message() {
4082        let stats = QwenTensorStats {
4083            min: 0.0,
4084            max: 1.0,
4085            mean: 0.5,
4086            nan_count: 2,
4087            pos_inf_count: 1,
4088            neg_inf_count: 1,
4089            total: 10,
4090        };
4091
4092        let message = QwenImageEngine::format_tensor_stats("sample", stats);
4093
4094        assert!(message.contains("NaN=2/10"));
4095        assert!(message.contains("+Inf=1"));
4096        assert!(message.contains("-Inf=1"));
4097    }
4098
4099    #[test]
4100    fn qwen_oom_fallback_returns_primary_success_without_running_fallback() {
4101        let mut progress = ProgressReporter::default();
4102        let messages = std::sync::Arc::new(std::sync::Mutex::new(Vec::<String>::new()));
4103        let messages_clone = messages.clone();
4104        progress.set_callback(Box::new(move |event| {
4105            if let ProgressEvent::Info { message } = event {
4106                messages_clone.lock().unwrap().push(message);
4107            }
4108        }));
4109
4110        let fallback_called = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
4111        let fallback_called_clone = fallback_called.clone();
4112        let value = QwenImageEngine::with_cuda_oom_cpu_fallback(
4113            || Ok(7usize),
4114            || {
4115                fallback_called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
4116                Ok(9usize)
4117            },
4118            true,
4119            &Device::Cpu,
4120            &progress,
4121            "retrying",
4122            |_| true,
4123        )
4124        .unwrap();
4125
4126        assert_eq!(value, 7);
4127        assert!(!fallback_called.load(std::sync::atomic::Ordering::SeqCst));
4128        assert!(messages.lock().unwrap().is_empty());
4129    }
4130
4131    #[test]
4132    fn qwen_oom_fallback_retries_when_primary_ooms_on_cuda() {
4133        let mut progress = ProgressReporter::default();
4134        let messages = std::sync::Arc::new(std::sync::Mutex::new(Vec::<String>::new()));
4135        let messages_clone = messages.clone();
4136        progress.set_callback(Box::new(move |event| {
4137            if let ProgressEvent::Info { message } = event {
4138                messages_clone.lock().unwrap().push(message);
4139            }
4140        }));
4141
4142        let value = QwenImageEngine::with_cuda_oom_cpu_fallback(
4143            || Err(anyhow::anyhow!("cudaErrorMemoryAllocation")),
4144            || Ok(11usize),
4145            true,
4146            &Device::Cpu,
4147            &progress,
4148            "retrying",
4149            QwenImageEngine::is_oom_error,
4150        )
4151        .unwrap();
4152
4153        assert_eq!(value, 11);
4154        assert_eq!(messages.lock().unwrap().as_slice(), ["retrying"]);
4155    }
4156
4157    #[test]
4158    fn qwen_oom_fallback_does_not_retry_non_oom_errors() {
4159        let progress = ProgressReporter::default();
4160        let err = QwenImageEngine::with_cuda_oom_cpu_fallback(
4161            || Err(anyhow::anyhow!("not an oom")),
4162            || Ok(11usize),
4163            true,
4164            &Device::Cpu,
4165            &progress,
4166            "retrying",
4167            QwenImageEngine::is_oom_error,
4168        )
4169        .unwrap_err();
4170
4171        assert!(err.to_string().contains("not an oom"));
4172    }
4173
4174    #[test]
4175    fn qwen_tiled_fallback_returns_primary_success_without_retrying() {
4176        let progress = ProgressReporter::default();
4177        let tiled_called = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
4178        let cpu_called = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
4179        let tiled_called_clone = tiled_called.clone();
4180        let cpu_called_clone = cpu_called.clone();
4181
4182        let value = QwenImageEngine::with_cuda_tiled_then_cpu_fallback(
4183            || Ok(5usize),
4184            || {
4185                tiled_called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
4186                Ok(7usize)
4187            },
4188            || {
4189                cpu_called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
4190                Ok(9usize)
4191            },
4192            true,
4193            false,
4194            &Device::Cpu,
4195            &progress,
4196            "tiled",
4197            "cpu",
4198            |_| true,
4199        )
4200        .unwrap();
4201
4202        assert_eq!(value, 5);
4203        assert!(!tiled_called.load(std::sync::atomic::Ordering::SeqCst));
4204        assert!(!cpu_called.load(std::sync::atomic::Ordering::SeqCst));
4205    }
4206
4207    #[test]
4208    fn qwen_tiled_fallback_uses_tiled_result_before_cpu() {
4209        let mut progress = ProgressReporter::default();
4210        let messages = std::sync::Arc::new(std::sync::Mutex::new(Vec::<String>::new()));
4211        let messages_clone = messages.clone();
4212        progress.set_callback(Box::new(move |event| {
4213            if let ProgressEvent::Info { message } = event {
4214                messages_clone.lock().unwrap().push(message);
4215            }
4216        }));
4217
4218        let cpu_called = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
4219        let cpu_called_clone = cpu_called.clone();
4220        let value = QwenImageEngine::with_cuda_tiled_then_cpu_fallback(
4221            || Err(anyhow::anyhow!("out of memory")),
4222            || Ok(13usize),
4223            || {
4224                cpu_called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
4225                Ok(17usize)
4226            },
4227            true,
4228            false,
4229            &Device::Cpu,
4230            &progress,
4231            "tiled",
4232            "cpu",
4233            QwenImageEngine::is_oom_error,
4234        )
4235        .unwrap();
4236
4237        assert_eq!(value, 13);
4238        assert!(!cpu_called.load(std::sync::atomic::Ordering::SeqCst));
4239        assert_eq!(messages.lock().unwrap().as_slice(), ["tiled"]);
4240    }
4241
4242    #[test]
4243    fn qwen_tiled_fallback_uses_cpu_after_tiled_oom() {
4244        let mut progress = ProgressReporter::default();
4245        let messages = std::sync::Arc::new(std::sync::Mutex::new(Vec::<String>::new()));
4246        let messages_clone = messages.clone();
4247        progress.set_callback(Box::new(move |event| {
4248            if let ProgressEvent::Info { message } = event {
4249                messages_clone.lock().unwrap().push(message);
4250            }
4251        }));
4252
4253        let value = QwenImageEngine::with_cuda_tiled_then_cpu_fallback(
4254            || Err(anyhow::anyhow!("OUT_OF_MEMORY")),
4255            || Err(anyhow::anyhow!("OUT_OF_MEMORY")),
4256            || Ok(19usize),
4257            true,
4258            false,
4259            &Device::Cpu,
4260            &progress,
4261            "tiled",
4262            "cpu",
4263            QwenImageEngine::is_oom_error,
4264        )
4265        .unwrap();
4266
4267        assert_eq!(value, 19);
4268        assert_eq!(messages.lock().unwrap().as_slice(), ["tiled", "cpu"]);
4269    }
4270
4271    #[test]
4272    fn qwen_tiled_fallback_propagates_non_oom_tiled_error() {
4273        let progress = ProgressReporter::default();
4274        let err = QwenImageEngine::with_cuda_tiled_then_cpu_fallback(
4275            || Err(anyhow::anyhow!("out of memory")),
4276            || Err(anyhow::anyhow!("bad tiled decode")),
4277            || Ok(19usize),
4278            true,
4279            false,
4280            &Device::Cpu,
4281            &progress,
4282            "tiled",
4283            "cpu",
4284            QwenImageEngine::is_oom_error,
4285        )
4286        .unwrap_err();
4287
4288        assert!(err.to_string().contains("bad tiled decode"));
4289    }
4290
4291    #[test]
4292    fn qwen_proactive_tiled_policy_selects_native_cuda_under_pressure() {
4293        assert!(QwenImageEngine::should_proactively_tile_vae_decode(
4294            1328,
4295            1328,
4296            true,
4297            6_000_000_000
4298        ));
4299        assert!(!QwenImageEngine::should_proactively_tile_vae_decode(
4300            512,
4301            512,
4302            true,
4303            6_000_000_000
4304        ));
4305        assert!(!QwenImageEngine::should_proactively_tile_vae_decode(
4306            1328,
4307            1328,
4308            false,
4309            6_000_000_000
4310        ));
4311        assert!(!QwenImageEngine::should_proactively_tile_vae_decode(
4312            1328,
4313            1328,
4314            true,
4315            16_000_000_000
4316        ));
4317    }
4318
4319    #[test]
4320    fn qwen_proactive_tiled_decode_skips_primary_full_decode() {
4321        let mut progress = ProgressReporter::default();
4322        let messages = std::sync::Arc::new(std::sync::Mutex::new(Vec::<String>::new()));
4323        let messages_clone = messages.clone();
4324        progress.set_callback(Box::new(move |event| {
4325            if let ProgressEvent::Info { message } = event {
4326                messages_clone.lock().unwrap().push(message);
4327            }
4328        }));
4329
4330        let primary_called = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
4331        let primary_called_clone = primary_called.clone();
4332        let value = QwenImageEngine::with_cuda_tiled_then_cpu_fallback(
4333            || {
4334                primary_called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
4335                Ok(3usize)
4336            },
4337            || Ok(7usize),
4338            || Ok(9usize),
4339            true,
4340            true,
4341            &Device::Cpu,
4342            &progress,
4343            "tiled after oom",
4344            "cpu",
4345            QwenImageEngine::is_oom_error,
4346        )
4347        .unwrap();
4348
4349        assert_eq!(value, 7);
4350        assert!(!primary_called.load(std::sync::atomic::Ordering::SeqCst));
4351        assert_eq!(
4352            messages.lock().unwrap().as_slice(),
4353            ["Selecting tiled GPU VAE decode proactively"]
4354        );
4355    }
4356
4357    #[test]
4358    fn qwen_hot_text_encoder_keeps_gpu_after_cache_miss_with_headroom() {
4359        let action = QwenImageEngine::qwen2_text_encoder_post_encode_action(
4360            Qwen2TextEncoderResidencyInput {
4361                on_gpu: true,
4362                is_quantized: true,
4363                is_metal: false,
4364                keep_te_ram: false,
4365                prompt_cache_miss: true,
4366                transformer_resident: true,
4367                free_vram_bytes: 10_000_000_000,
4368                required_vram_bytes: 8_000_000_000,
4369            },
4370        );
4371
4372        assert_eq!(action, Qwen2TextEncoderPostEncodeAction::KeepGpu);
4373    }
4374
4375    #[test]
4376    fn qwen_hot_text_encoder_drops_after_cache_hit_even_with_headroom() {
4377        let action = QwenImageEngine::qwen2_text_encoder_post_encode_action(
4378            Qwen2TextEncoderResidencyInput {
4379                on_gpu: true,
4380                is_quantized: true,
4381                is_metal: false,
4382                keep_te_ram: false,
4383                prompt_cache_miss: false,
4384                transformer_resident: true,
4385                free_vram_bytes: 10_000_000_000,
4386                required_vram_bytes: 8_000_000_000,
4387            },
4388        );
4389
4390        assert_eq!(action, Qwen2TextEncoderPostEncodeAction::Drop);
4391    }
4392
4393    #[test]
4394    fn qwen_hot_text_encoder_drops_under_transformer_pressure() {
4395        let action = QwenImageEngine::qwen2_text_encoder_post_encode_action(
4396            Qwen2TextEncoderResidencyInput {
4397                on_gpu: true,
4398                is_quantized: true,
4399                is_metal: false,
4400                keep_te_ram: false,
4401                prompt_cache_miss: true,
4402                transformer_resident: true,
4403                free_vram_bytes: 7_999_999_999,
4404                required_vram_bytes: 8_000_000_000,
4405            },
4406        );
4407
4408        assert_eq!(action, Qwen2TextEncoderPostEncodeAction::Drop);
4409    }
4410
4411    #[test]
4412    fn qwen_hot_text_encoder_parks_bf16_when_keep_ram_enabled() {
4413        let action = QwenImageEngine::qwen2_text_encoder_post_encode_action(
4414            Qwen2TextEncoderResidencyInput {
4415                on_gpu: true,
4416                is_quantized: false,
4417                is_metal: false,
4418                keep_te_ram: true,
4419                prompt_cache_miss: true,
4420                transformer_resident: true,
4421                free_vram_bytes: 7_999_999_999,
4422                required_vram_bytes: 8_000_000_000,
4423            },
4424        );
4425
4426        assert_eq!(action, Qwen2TextEncoderPostEncodeAction::ParkCpu);
4427    }
4428
4429    #[test]
4430    fn qwen_hot_text_encoder_never_parks_quantized() {
4431        let action = QwenImageEngine::qwen2_text_encoder_post_encode_action(
4432            Qwen2TextEncoderResidencyInput {
4433                on_gpu: true,
4434                is_quantized: true,
4435                is_metal: false,
4436                keep_te_ram: true,
4437                prompt_cache_miss: true,
4438                transformer_resident: true,
4439                free_vram_bytes: 7_999_999_999,
4440                required_vram_bytes: 8_000_000_000,
4441            },
4442        );
4443
4444        assert_eq!(action, Qwen2TextEncoderPostEncodeAction::Drop);
4445    }
4446
4447    #[test]
4448    fn qwen_hot_text_encoder_drops_when_transformer_not_resident() {
4449        let action = QwenImageEngine::qwen2_text_encoder_post_encode_action(
4450            Qwen2TextEncoderResidencyInput {
4451                on_gpu: true,
4452                is_quantized: true,
4453                is_metal: false,
4454                keep_te_ram: false,
4455                prompt_cache_miss: true,
4456                transformer_resident: false,
4457                free_vram_bytes: 10_000_000_000,
4458                required_vram_bytes: 8_000_000_000,
4459            },
4460        );
4461
4462        assert_eq!(action, Qwen2TextEncoderPostEncodeAction::Drop);
4463    }
4464
4465    #[test]
4466    fn qwen_transformer_hot_vae_eligibility_requires_quantized_cuda_components() {
4467        assert!(QwenImageEngine::qwen_transformer_can_stay_hot_for_vae(
4468            true, true, true
4469        ));
4470        assert!(!QwenImageEngine::qwen_transformer_can_stay_hot_for_vae(
4471            false, true, true
4472        ));
4473        assert!(!QwenImageEngine::qwen_transformer_can_stay_hot_for_vae(
4474            true, false, true
4475        ));
4476        assert!(!QwenImageEngine::qwen_transformer_can_stay_hot_for_vae(
4477            true, true, false
4478        ));
4479    }
4480
4481    #[test]
4482    fn qwen_transformer_paths_prefer_shards_when_present() {
4483        let dir = temp_test_dir("mold-qwen-shards");
4484        let shard_a = touch(&dir, "transformer-00001-of-00002.safetensors");
4485        let shard_b = touch(&dir, "transformer-00002-of-00002.safetensors");
4486        let engine = QwenImageEngine::new(
4487            "qwen-image:q4".to_string(),
4488            qwen_image_model_paths(
4489                dir.join("transformer.safetensors"),
4490                vec![shard_a.clone(), shard_b.clone()],
4491                dir.join("vae.safetensors"),
4492                Some(dir.join("tokenizer.json")),
4493            ),
4494            LoadStrategy::Sequential,
4495            0,
4496            false,
4497            None,
4498        );
4499
4500        assert_eq!(engine.transformer_paths(), vec![shard_a, shard_b]);
4501
4502        fs::remove_dir_all(dir).ok();
4503    }
4504
4505    #[test]
4506    fn qwen_validate_paths_accepts_existing_files() {
4507        let dir = temp_test_dir("mold-qwen-validate-ok");
4508        let shard_a = touch(&dir, "transformer-00001-of-00002.safetensors");
4509        let shard_b = touch(&dir, "transformer-00002-of-00002.safetensors");
4510        let vae = touch(&dir, "vae.safetensors");
4511        let tokenizer = touch(&dir, "tokenizer.json");
4512        let gguf = touch(&dir, "transformer.gguf");
4513
4514        let sharded = QwenImageEngine::new(
4515            "qwen-image:bf16".to_string(),
4516            qwen_image_model_paths(
4517                dir.join("transformer.safetensors"),
4518                vec![shard_a, shard_b],
4519                vae.clone(),
4520                Some(tokenizer.clone()),
4521            ),
4522            LoadStrategy::Sequential,
4523            0,
4524            false,
4525            None,
4526        );
4527        assert_eq!(sharded.validate_paths().unwrap(), tokenizer);
4528        assert!(!sharded.detect_is_quantized());
4529
4530        let quantized = QwenImageEngine::new(
4531            "qwen-image:q4".to_string(),
4532            qwen_image_model_paths(gguf, vec![], vae, Some(dir.join("tokenizer.json"))),
4533            LoadStrategy::Sequential,
4534            0,
4535            false,
4536            None,
4537        );
4538        assert!(quantized.detect_is_quantized());
4539
4540        fs::remove_dir_all(dir).ok();
4541    }
4542
4543    #[test]
4544    fn qwen_validate_paths_requires_text_tokenizer() {
4545        let dir = temp_test_dir("mold-qwen-validate-missing");
4546        let engine = QwenImageEngine::new(
4547            "qwen-image:q4".to_string(),
4548            qwen_image_model_paths(
4549                dir.join("transformer.gguf"),
4550                vec![],
4551                dir.join("vae.safetensors"),
4552                None,
4553            ),
4554            LoadStrategy::Sequential,
4555            0,
4556            false,
4557            None,
4558        );
4559
4560        let err = engine.validate_paths().unwrap_err();
4561        assert!(err.to_string().contains("text tokenizer path required"));
4562
4563        fs::remove_dir_all(dir).ok();
4564    }
4565
4566    #[test]
4567    fn qwen_image_loads_text_tokenizer_through_shared_pool() {
4568        let dir = temp_test_dir("mold-qwen-tokenizer-pool");
4569        let tokenizer_path = dir.join("tokenizer.json");
4570        tokenizers::Tokenizer::new(BPE::default())
4571            .save(&tokenizer_path, false)
4572            .unwrap();
4573
4574        let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
4575        let pooled = shared_pool
4576            .lock()
4577            .unwrap()
4578            .load_tokenizer(&tokenizer_path)
4579            .unwrap();
4580
4581        let engine = QwenImageEngine::new(
4582            "qwen-image:q4".to_string(),
4583            qwen_image_model_paths(
4584                dir.join("transformer.gguf"),
4585                vec![],
4586                dir.join("vae.safetensors"),
4587                Some(tokenizer_path.clone()),
4588            ),
4589            LoadStrategy::Sequential,
4590            0,
4591            false,
4592            Some(shared_pool),
4593        );
4594
4595        let loaded = engine.load_text_tokenizer(&tokenizer_path).unwrap();
4596
4597        assert!(Arc::ptr_eq(&pooled, &loaded));
4598        fs::remove_dir_all(dir).ok();
4599    }
4600
4601    #[test]
4602    fn qwen_image_loads_vae_tensors_through_shared_pool() {
4603        let dir = temp_test_dir("mold-qwen-vae-pool");
4604        let vae_path = dir.join("vae.safetensors");
4605        let weight = 1.0f32.to_le_bytes();
4606        let mut tensors = HashMap::new();
4607        tensors.insert(
4608            "encoder.conv_in.weight".to_string(),
4609            TensorView::new(SafeDtype::F32, vec![1], &weight).unwrap(),
4610        );
4611        serialize_to_file(&tensors, &None, &vae_path).unwrap();
4612
4613        let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
4614        let pooled = shared_pool
4615            .lock()
4616            .unwrap()
4617            .load_safetensors_cpu_tensors(std::slice::from_ref(&vae_path))
4618            .unwrap()
4619            .unwrap();
4620
4621        let engine = QwenImageEngine::new(
4622            "qwen-image:q4".to_string(),
4623            qwen_image_model_paths(
4624                dir.join("transformer.gguf"),
4625                vec![],
4626                vae_path.clone(),
4627                Some(dir.join("tokenizer.json")),
4628            ),
4629            LoadStrategy::Sequential,
4630            0,
4631            false,
4632            Some(shared_pool),
4633        );
4634
4635        let loaded = engine.load_vae_cpu_tensors().unwrap().unwrap();
4636
4637        assert!(Arc::ptr_eq(&pooled, &loaded));
4638        fs::remove_dir_all(dir).ok();
4639    }
4640
4641    #[test]
4642    fn qwen_img2img_uses_minus_one_to_one_source_normalization() {
4643        assert_eq!(
4644            QwenImageEngine::img2img_source_normalize_range(),
4645            img_utils::NormalizeRange::MinusOneToOne
4646        );
4647    }
4648}