Skip to main content

mold_inference/zimage/
pipeline.rs

1use anyhow::{bail, Result};
2use candle_core::{DType, Device, IndexOp, Shape, Tensor};
3use candle_transformers::models::z_image::{
4    calculate_shift, postprocess_image, AutoEncoderKL, Config, FlowMatchEulerDiscreteScheduler,
5    SchedulerConfig, VaeConfig,
6};
7use candle_transformers::quantized_var_builder;
8use mold_core::{GenerateRequest, GenerateResponse, ImageData, LoraWeight, ModelPaths};
9use std::borrow::Cow;
10use std::collections::{BTreeMap, HashMap};
11use std::path::Path;
12use std::sync::{Arc, Mutex};
13use std::time::Instant;
14use tokenizers::Tokenizer;
15
16use super::gguf_dense::load_gguf_dense_transformer;
17use super::transformer::{MoldZImageTransformer2DModel, ZImageTransformer};
18use crate::cache::{
19    clear_cache, get_or_insert_cached_tensor, prompt_text_key, restore_cached_tensor, CachedTensor,
20    LruCache, DEFAULT_PROMPT_CACHE_CAPACITY,
21};
22use crate::device::{
23    check_memory_budget, effective_device_ref, fmt_gb, free_vram_bytes, memory_status_string,
24    preflight_memory_check, should_use_gpu, usable_free_vram_bytes,
25};
26// Re-exported for tests (test harness is disabled via `test = false` in Cargo.toml,
27// but tests reference this constant via `super::*`).
28#[cfg(test)]
29use crate::device::QWEN3_FP16_VRAM_THRESHOLD;
30use crate::encoders;
31use crate::engine::{rand_seed, InferenceEngine, LoadStrategy};
32use crate::engine_base::EngineBase;
33use crate::image::{build_output_metadata, encode_image};
34use crate::img_utils;
35use crate::progress::{ProgressCallback, ProgressEvent, ProgressReporter};
36
/// Minimum free VRAM (bytes) required to place Z-Image VAE on GPU.
/// The VAE itself is small (~160MB), but decode at 1024x1024 needs ~6GB workspace
/// for conv2d im2col expansions through the upsampling blocks.
const VAE_DECODE_VRAM_THRESHOLD: u64 = 6_500_000_000;
/// Eager mode loads the VAE before denoising but drops the transformer before
/// decode. Use a weight-load threshold here, not the decode workspace threshold,
/// so CUDA/Metal decode still happens on GPU after the transformer is freed.
const VAE_WEIGHT_LOAD_VRAM_THRESHOLD: u64 = 600_000_000;

/// Base image-token sequence length passed to `calculate_shift` when
/// `build_zimage_scheduler` computes the resolution-dependent shift `mu`
/// (value from the reference implementation).
const BASE_IMAGE_SEQ_LEN: usize = 256;
/// Maximum image-token sequence length passed to `calculate_shift` alongside
/// `BASE_IMAGE_SEQ_LEN` (value from the reference implementation).
const MAX_IMAGE_SEQ_LEN: usize = 4096;
/// Key prefix `ZImageSafetensorsBackend` additionally tries (for both the
/// requested name and its alias) when resolving transformer tensor names; per
/// the name, this is the layout of single-file checkpoints.
const ZIMAGE_SINGLE_FILE_PREFIX: &str = "model.diffusion_model.";
50
51struct ZImageSafetensorsBackend {
52    st: candle_core::safetensors::MmapedSafetensors,
53}
54
55impl ZImageSafetensorsBackend {
56    fn new(st: candle_core::safetensors::MmapedSafetensors) -> Self {
57        Self { st }
58    }
59
60    fn resolve_stored_name<'a>(&'a self, name: &'a str) -> Option<Cow<'a, str>> {
61        if self.st.get(name).is_ok() {
62            return Some(Cow::Borrowed(name));
63        }
64        if let Some(alias) = zimage_safetensors_alias(name) {
65            if self.st.get(alias.as_ref()).is_ok() {
66                return Some(alias);
67            }
68        }
69        let prefixed = format!("{ZIMAGE_SINGLE_FILE_PREFIX}{name}");
70        if self.st.get(&prefixed).is_ok() {
71            return Some(Cow::Owned(prefixed));
72        }
73        if let Some(alias) = zimage_safetensors_alias(name) {
74            let prefixed_alias = format!("{ZIMAGE_SINGLE_FILE_PREFIX}{}", alias.as_ref());
75            if self.st.get(&prefixed_alias).is_ok() {
76                return Some(Cow::Owned(prefixed_alias));
77            }
78        }
79        None
80    }
81
82    fn stored_name<'a>(&'a self, name: &'a str) -> Cow<'a, str> {
83        self.resolve_stored_name(name)
84            .unwrap_or(Cow::Borrowed(name))
85    }
86
87    fn load_cast(&self, name: &str, dtype: DType, dev: &Device) -> candle_core::Result<Tensor> {
88        let stored_name = self.stored_name(name);
89        let tensor = self.st.load(stored_name.as_ref(), dev)?;
90        if tensor.dtype() != dtype {
91            tensor.to_dtype(dtype)
92        } else {
93            Ok(tensor)
94        }
95    }
96
97    fn load_tensor(
98        &self,
99        name: &str,
100        expected_shape: Option<&Shape>,
101        dtype: DType,
102        dev: &Device,
103    ) -> candle_core::Result<Tensor> {
104        if let Some((source_name, component)) = zimage_qkv_request(name) {
105            return self.load_qkv_split(&source_name, component, expected_shape, dtype, dev);
106        }
107        self.load_cast(name, dtype, dev)
108    }
109
110    fn load_qkv_split(
111        &self,
112        source_name: &str,
113        component: usize,
114        expected_shape: Option<&Shape>,
115        dtype: DType,
116        dev: &Device,
117    ) -> candle_core::Result<Tensor> {
118        let qkv = self.load_cast(source_name, dtype, dev)?;
119        let rows = qkv.dim(0)?;
120        let split_rows = expected_shape
121            .and_then(|shape| shape.dims().first().copied())
122            .unwrap_or(rows / 3);
123        if component >= 3 || split_rows == 0 || rows != split_rows * 3 {
124            return Err(candle_core::Error::msg(format!(
125                "invalid fused QKV shape for {source_name}: rows={rows}, split_rows={split_rows}"
126            )));
127        }
128        qkv.narrow(0, component * split_rows, split_rows)?
129            .contiguous()
130    }
131}
132
133impl candle_nn::var_builder::SimpleBackend for ZImageSafetensorsBackend {
134    fn get(
135        &self,
136        shape: Shape,
137        name: &str,
138        _init: candle_nn::Init,
139        dtype: DType,
140        dev: &Device,
141    ) -> candle_core::Result<Tensor> {
142        let tensor = self.load_tensor(name, Some(&shape), dtype, dev)?;
143        if tensor.shape() != &shape {
144            Err(candle_core::Error::UnexpectedShape {
145                msg: format!("shape mismatch for {name}"),
146                expected: shape,
147                got: tensor.shape().clone(),
148            })?
149        }
150        Ok(tensor)
151    }
152
153    fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> candle_core::Result<Tensor> {
154        self.load_tensor(name, None, dtype, dev)
155    }
156
157    fn contains_tensor(&self, name: &str) -> bool {
158        if let Some((source_name, _)) = zimage_qkv_request(name) {
159            return self.resolve_stored_name(&source_name).is_some();
160        }
161        self.resolve_stored_name(name).is_some()
162    }
163}
164
165enum ZImageVaeTensorSource {
166    Mmap(candle_core::safetensors::MmapedSafetensors),
167    Cpu(Arc<HashMap<String, Tensor>>),
168}
169
170/// Z-Image VAE accepts both diffusers keys and LDM/Civitai aliases, plus a
171/// narrow `[out, in, 1, 1]` to `[out, in]` reshape. Keep both mmap and cached
172/// CPU tensor sources behind this backend so those rules cannot be bypassed.
173struct ZImageVaeSafetensorsBackend {
174    source: ZImageVaeTensorSource,
175    aliases: BTreeMap<String, String>,
176}
177
178impl ZImageVaeSafetensorsBackend {
179    fn new(st: candle_core::safetensors::MmapedSafetensors) -> Self {
180        let aliases = Self::aliases_from_names(st.tensors().into_iter().map(|(name, _)| name));
181        Self {
182            source: ZImageVaeTensorSource::Mmap(st),
183            aliases,
184        }
185    }
186
187    fn from_cpu_tensors(tensors: Arc<HashMap<String, Tensor>>) -> Self {
188        let aliases = Self::aliases_from_names(tensors.keys().cloned());
189        Self {
190            source: ZImageVaeTensorSource::Cpu(tensors),
191            aliases,
192        }
193    }
194
195    fn aliases_from_names(names: impl IntoIterator<Item = String>) -> BTreeMap<String, String> {
196        names
197            .into_iter()
198            .filter_map(|name| zimage_vae_diffusers_name(&name).map(|diffusers| (diffusers, name)))
199            .collect()
200    }
201
202    fn resolve_stored_name<'a>(&'a self, name: &'a str) -> Cow<'a, str> {
203        if self.contains_stored_tensor(name) {
204            return Cow::Borrowed(name);
205        }
206        self.aliases
207            .get(name)
208            .map(|source| Cow::Borrowed(source.as_str()))
209            .unwrap_or(Cow::Borrowed(name))
210    }
211
212    fn contains_stored_tensor(&self, name: &str) -> bool {
213        match &self.source {
214            ZImageVaeTensorSource::Mmap(st) => st.get(name).is_ok(),
215            ZImageVaeTensorSource::Cpu(tensors) => tensors.contains_key(name),
216        }
217    }
218
219    fn load_cast(&self, name: &str, dtype: DType, dev: &Device) -> candle_core::Result<Tensor> {
220        let stored_name = self.resolve_stored_name(name);
221        let tensor = match &self.source {
222            ZImageVaeTensorSource::Mmap(st) => st.load(stored_name.as_ref(), dev)?,
223            ZImageVaeTensorSource::Cpu(tensors) => tensors
224                .get(stored_name.as_ref())
225                .ok_or_else(|| {
226                    candle_core::Error::msg(format!(
227                        "missing Z-Image VAE tensor {}",
228                        stored_name.as_ref()
229                    ))
230                })?
231                .to_device(dev)?,
232        };
233        if tensor.dtype() != dtype {
234            tensor.to_dtype(dtype)
235        } else {
236            Ok(tensor)
237        }
238    }
239}
240
241impl candle_nn::var_builder::SimpleBackend for ZImageVaeSafetensorsBackend {
242    fn get(
243        &self,
244        shape: Shape,
245        name: &str,
246        _init: candle_nn::Init,
247        dtype: DType,
248        dev: &Device,
249    ) -> candle_core::Result<Tensor> {
250        let mut tensor = self.load_cast(name, dtype, dev)?;
251        if tensor.shape() != &shape
252            && tensor.dims().len() == 4
253            && shape.dims().len() == 2
254            && tensor.dims()[0] == shape.dims()[0]
255            && tensor.dims()[1] == shape.dims()[1]
256            && tensor.dims()[2] == 1
257            && tensor.dims()[3] == 1
258        {
259            tensor = tensor.reshape(shape.dims())?;
260        }
261        if tensor.shape() != &shape {
262            Err(candle_core::Error::UnexpectedShape {
263                msg: format!("shape mismatch for {name}"),
264                expected: shape,
265                got: tensor.shape().clone(),
266            })?
267        }
268        Ok(tensor)
269    }
270
271    fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> candle_core::Result<Tensor> {
272        self.load_cast(name, dtype, dev)
273    }
274
275    fn contains_tensor(&self, name: &str) -> bool {
276        self.contains_stored_tensor(name) || self.aliases.contains_key(name)
277    }
278}
279
280fn zimage_vae_diffusers_name(source_name: &str) -> Option<String> {
281    if source_name.starts_with("first_stage_model.") {
282        return crate::loader::vae_keys::apply_vae_rename(source_name);
283    }
284    if source_name.starts_with("encoder.")
285        || source_name.starts_with("decoder.")
286        || source_name.starts_with("quant_conv.")
287        || source_name.starts_with("post_quant_conv.")
288    {
289        return crate::loader::vae_keys::apply_vae_rename(&format!(
290            "first_stage_model.{source_name}"
291        ));
292    }
293    None
294}
295
296fn zimage_qkv_request(name: &str) -> Option<(String, usize)> {
297    for (suffix, component) in [
298        (".attention.to_q.weight", 0),
299        (".attention.to_k.weight", 1),
300        (".attention.to_v.weight", 2),
301    ] {
302        if let Some(prefix) = name.strip_suffix(suffix) {
303            return Some((format!("{prefix}.attention.qkv.weight"), component));
304        }
305    }
306    None
307}
308
309fn zimage_safetensors_alias(name: &str) -> Option<Cow<'_, str>> {
310    match name {
311        "all_x_embedder.2-1.weight" => return Some(Cow::Borrowed("x_embedder.weight")),
312        "all_x_embedder.2-1.bias" => return Some(Cow::Borrowed("x_embedder.bias")),
313        "all_final_layer.2-1.linear.weight" => {
314            return Some(Cow::Borrowed("final_layer.linear.weight"));
315        }
316        "all_final_layer.2-1.linear.bias" => {
317            return Some(Cow::Borrowed("final_layer.linear.bias"));
318        }
319        "all_final_layer.2-1.adaLN_modulation.1.weight" => {
320            return Some(Cow::Borrowed("final_layer.adaLN_modulation.1.weight"));
321        }
322        "all_final_layer.2-1.adaLN_modulation.1.bias" => {
323            return Some(Cow::Borrowed("final_layer.adaLN_modulation.1.bias"));
324        }
325        _ => {}
326    }
327    for (requested, stored) in [
328        (".attention.to_out.0.weight", ".attention.out.weight"),
329        (".attention.norm_q.weight", ".attention.q_norm.weight"),
330        (".attention.norm_k.weight", ".attention.k_norm.weight"),
331    ] {
332        if let Some(prefix) = name.strip_suffix(requested) {
333            return Some(Cow::Owned(format!("{prefix}{stored}")));
334        }
335    }
336    None
337}
338const BASE_SHIFT: f64 = 0.5;
339const MAX_SHIFT: f64 = 1.15;
340
341fn build_zimage_scheduler(
342    num_steps: usize,
343    image_seq_len: usize,
344    strength: Option<f64>,
345) -> (FlowMatchEulerDiscreteScheduler, usize) {
346    let mut scheduler = FlowMatchEulerDiscreteScheduler::new(SchedulerConfig::z_image_turbo());
347    let mu = calculate_shift(
348        image_seq_len,
349        BASE_IMAGE_SEQ_LEN,
350        MAX_IMAGE_SEQ_LEN,
351        BASE_SHIFT,
352        MAX_SHIFT,
353    );
354    let sigmas: Vec<f64> = (0..=num_steps)
355        .map(|v| v as f64 / num_steps as f64)
356        .rev()
357        .map(|t| {
358            if !(0.0..1.0).contains(&t) {
359                t
360            } else {
361                let e_mu = mu.exp();
362                e_mu / (e_mu + (1.0 / t - 1.0))
363            }
364        })
365        .collect();
366    scheduler.timesteps = sigmas[..sigmas.len().saturating_sub(1)]
367        .iter()
368        .map(|sigma| sigma * scheduler.config.num_train_timesteps as f64)
369        .collect();
370    scheduler.sigmas = sigmas;
371    let start_index = strength
372        .map(|strength| crate::img2img::img2img_start_index(num_steps, strength))
373        .unwrap_or(0);
374    if start_index > 0 {
375        scheduler.timesteps = scheduler.timesteps[start_index..].to_vec();
376        scheduler.sigmas = scheduler.sigmas[start_index..].to_vec();
377    }
378    scheduler.reset();
379    (scheduler, start_index)
380}
381
382fn load_zimage_vae(
383    path: &std::path::Path,
384    dtype: DType,
385    device: &Device,
386    progress: &ProgressReporter,
387    cached_tensors: Option<Arc<HashMap<String, Tensor>>>,
388) -> Result<AutoEncoderKL> {
389    use candle_core::safetensors::MmapedSafetensors;
390
391    let bytes_total = std::fs::metadata(path).map(|m| m.len()).unwrap_or(0);
392    progress.weight_load("VAE", 0, bytes_total);
393    let backend = if let Some(tensors) = cached_tensors {
394        ZImageVaeSafetensorsBackend::from_cpu_tensors(tensors)
395    } else {
396        let st = unsafe { MmapedSafetensors::multi(&[path])? };
397        ZImageVaeSafetensorsBackend::new(st)
398    };
399    let vae_vb = candle_nn::VarBuilder::from_backend(Box::new(backend), dtype, device.clone());
400    progress.weight_load("VAE", bytes_total, bytes_total);
401    AutoEncoderKL::new(&VaeConfig::z_image(), vae_vb).map_err(Into::into)
402}
403
404fn zimage_qwen3_preference<'a>(
405    configured: Option<&'a str>,
406    text_encoder_paths: &[std::path::PathBuf],
407) -> Option<&'a str> {
408    if configured.is_none() && zimage_has_recipe_text_encoder(text_encoder_paths) {
409        Some("bf16")
410    } else {
411        configured
412    }
413}
414
415fn zimage_has_recipe_text_encoder(text_encoder_paths: &[std::path::PathBuf]) -> bool {
416    text_encoder_paths.iter().any(|path| {
417        path.components()
418            .any(|component| component.as_os_str() == "civitai")
419    })
420}
421
422fn model_timestep(scheduler: &FlowMatchEulerDiscreteScheduler) -> f64 {
423    1.0 - scheduler.current_sigma()
424}
425
426fn zimage_debug_enabled() -> bool {
427    std::env::var_os("MOLD_ZIMAGE_DEBUG").is_some()
428}
429
430fn tensor_stats_summary(name: &str, tensor: &Tensor) -> Result<String> {
431    let flat = tensor.to_dtype(DType::F32)?.flatten_all()?;
432    let mean = flat.mean_all()?.to_scalar::<f32>()?;
433    let min = flat.min(0)?.to_scalar::<f32>()?;
434    let max = flat.max(0)?.to_scalar::<f32>()?;
435    let rms = flat.sqr()?.mean_all()?.to_scalar::<f32>()?.sqrt();
436    Ok(format!(
437        "{name}: mean={mean:.5} min={min:.5} max={max:.5} rms={rms:.5}"
438    ))
439}
440
441/// Loaded Z-Image model components, ready for inference.
442struct LoadedZImage {
443    /// Transformer is wrapped in Option so it can be dropped to free VRAM for VAE decode,
444    /// then reloaded from disk for the next generation (similar to FLUX's T5/CLIP offload).
445    transformer: Option<ZImageTransformer>,
446    text_encoder: encoders::qwen3::Qwen3Encoder,
447    vae: AutoEncoderKL,
448    transformer_cfg: Config,
449    /// GPU device for transformer + denoising
450    device: Device,
451    /// Device where the VAE lives (may be CPU if VRAM is extremely tight)
452    vae_device: Device,
453    dtype: DType,
454    /// Effective VAE dtype after `MOLD_VAE_DTYPE` resolution. May differ from
455    /// `dtype` when fp32 VAE decode is forced on GPU. Captured at load time.
456    /// CPU VAE is always F32 regardless of this field.
457    vae_dtype: DType,
458    /// Whether the transformer source file is GGUF (needed for reload/logging).
459    is_gguf: bool,
460    /// Path to the VAE safetensors file (needed for CPU fallback reload on OOM).
461    vae_path: std::path::PathBuf,
462}
463
464/// Z-Image inference engine backed by candle's z_image module.
465pub struct ZImageEngine {
466    base: EngineBase<LoadedZImage>,
467    /// Qwen3 variant preference: None/"auto" = VRAM-based, "bf16" = force BF16, "q8"/etc = specific.
468    qwen3_variant: Option<String>,
469    /// Force block-level transformer offload once the Z-Image runtime supports
470    /// streaming BF16 blocks. For now this is plumbed so requests are explicit
471    /// instead of being silently treated as ordinary eager loads.
472    offload: bool,
473    prompt_cache: Mutex<LruCache<String, CachedTensor>>,
474    /// Per-request placement override.
475    pending_placement: Option<mold_core::types::DevicePlacement>,
476    /// Per-request LoRA stack (effective: zero-scale entries already filtered).
477    ///
478    /// Set in [`InferenceEngine::generate`] at the entry point and cleared
479    /// at exit. `load_transformer` / `reload_transformer` consult this to
480    /// decide whether to wrap the constructed `VarBuilder` with a
481    /// `ZImageLoraBackend`.
482    pending_loras: Vec<LoraWeight>,
483    shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
484}
485
486/// Resolve the effective LoRA list for a request: `loras` (plural) wins
487/// over the singular `lora` when both are set, and zero-scale entries are
488/// dropped so they don't trigger a needless transformer rebuild.
489pub(crate) fn effective_zimage_loras(req: &GenerateRequest) -> Vec<LoraWeight> {
490    /// Threshold below which a LoRA scale is treated as off (matches FLUX).
491    const ZERO_SCALE_EPS: f64 = 1e-8;
492
493    let raw: Vec<LoraWeight> = if let Some(plural) = &req.loras {
494        if !plural.is_empty() {
495            plural.clone()
496        } else {
497            req.lora.iter().cloned().collect()
498        }
499    } else {
500        req.lora.iter().cloned().collect()
501    };
502    raw.into_iter()
503        .filter(|w| {
504            let keep = w.scale.abs() > ZERO_SCALE_EPS;
505            if !keep {
506                tracing::debug!(
507                    path = w.path.as_str(),
508                    scale = w.scale,
509                    "dropping zero-scale Z-Image LoRA"
510                );
511            }
512            keep
513        })
514        .collect()
515}
516
517#[derive(Debug, PartialEq, Eq)]
518enum ZImageOffloadDecision {
519    Disabled,
520    Selected,
521    Unsupported(&'static str),
522}
523
524fn zimage_offload_decision(
525    forced_offload: bool,
526    is_gguf: bool,
527    has_lora: bool,
528) -> ZImageOffloadDecision {
529    if !forced_offload {
530        return ZImageOffloadDecision::Disabled;
531    }
532    if is_gguf {
533        return ZImageOffloadDecision::Unsupported(
534            "Z-Image block-level offload is only planned for BF16/FP transformers; \
535             GGUF variants already use quantized/dense GGUF-specific paths",
536        );
537    }
538    if has_lora {
539        return ZImageOffloadDecision::Unsupported(
540            "Z-Image block-level offload with LoRA is not wired yet; \
541             LoRA merge/bypass semantics need a dedicated offload design",
542        );
543    }
544    ZImageOffloadDecision::Selected
545}
546
547impl ZImageEngine {
548    pub fn new(
549        model_name: String,
550        paths: ModelPaths,
551        qwen3_variant: Option<String>,
552        load_strategy: LoadStrategy,
553        gpu_ordinal: usize,
554        offload: bool,
555        shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
556    ) -> Self {
557        Self {
558            base: EngineBase::new(model_name, paths, load_strategy, gpu_ordinal),
559            qwen3_variant,
560            offload,
561            prompt_cache: Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY)),
562            pending_placement: None,
563            pending_loras: Vec::new(),
564            shared_pool,
565        }
566    }
567
568    fn load_text_tokenizer(&self, tokenizer_path: &Path) -> Result<Arc<Tokenizer>> {
569        if let Some(shared_pool) = &self.shared_pool {
570            return shared_pool.lock().unwrap().load_tokenizer(tokenizer_path);
571        }
572        Tokenizer::from_file(tokenizer_path)
573            .map(Arc::new)
574            .map_err(|e| anyhow::anyhow!("failed to load Qwen3 tokenizer: {e}"))
575    }
576
577    fn encode_prompt_cached(
578        progress: &ProgressReporter,
579        prompt_cache: &Mutex<LruCache<String, CachedTensor>>,
580        encoder: &mut encoders::qwen3::Qwen3Encoder,
581        prompt: &str,
582        device: &Device,
583        dtype: DType,
584    ) -> Result<(Tensor, Tensor)> {
585        let cache_key = prompt_text_key(prompt);
586        let (cap_feats, cache_hit) =
587            get_or_insert_cached_tensor(prompt_cache, cache_key, device, dtype, || {
588                progress.stage_start("Encoding prompt (Qwen3)");
589                let encode_start = Instant::now();
590                let (cap_feats, _token_count) = encoder.encode(prompt, device, dtype)?;
591                progress.stage_done("Encoding prompt (Qwen3)", encode_start.elapsed());
592                Ok(cap_feats)
593            })?;
594        if cache_hit {
595            progress.cache_hit("prompt conditioning");
596        }
597        let token_count = cap_feats.dim(1)?;
598        let cap_mask = Tensor::ones((1, token_count), DType::U8, device)?;
599        Ok((cap_feats, cap_mask))
600    }
601
602    /// Resolve transformer shard paths: use `transformer_shards` if non-empty,
603    /// otherwise treat `transformer` as a single file.
604    fn transformer_paths(&self) -> Vec<std::path::PathBuf> {
605        if !self.base.paths.transformer_shards.is_empty() {
606            self.base.paths.transformer_shards.clone()
607        } else {
608            vec![self.base.paths.transformer.clone()]
609        }
610    }
611
612    /// Detect if the transformer is GGUF quantized.
613    fn detect_is_gguf(&self) -> bool {
614        self.base
615            .paths
616            .transformer
617            .extension()
618            .and_then(|e| e.to_str())
619            .map(|e| e.eq_ignore_ascii_case("gguf"))
620            .unwrap_or(false)
621    }
622
623    /// Validate tokenizer path and transformer/VAE paths exist.
624    fn validate_paths(&self) -> Result<std::path::PathBuf> {
625        let text_tokenizer_path =
626            self.base.paths.text_tokenizer.as_ref().ok_or_else(|| {
627                anyhow::anyhow!("text tokenizer path required for Z-Image models")
628            })?;
629        if !text_tokenizer_path.exists() {
630            bail!(
631                "text tokenizer file not found: {}",
632                text_tokenizer_path.display()
633            );
634        }
635
636        let xformer_paths = self.transformer_paths();
637        for path in &xformer_paths {
638            if !path.exists() {
639                bail!("transformer file not found: {}", path.display());
640            }
641        }
642        if !self.base.paths.vae.exists() {
643            bail!("VAE file not found: {}", self.base.paths.vae.display());
644        }
645
646        Ok(text_tokenizer_path.clone())
647    }
648
649    /// Load transformer from disk, optionally merging LoRA deltas inline.
650    ///
651    /// When `self.pending_loras` is non-empty:
652    /// * **BF16 path**: build an mmap-backed `SimpleBackend` ourselves
653    ///   (`weight_loader::load_safetensors_with_progress` returns a
654    ///   `VarBuilder` whose backend can't be wrapped after the fact),
655    ///   wrap it with a `ZImageLoraBackend`, and feed the resulting
656    ///   `VarBuilder` to `ZImageTransformer2DModel::new`.
657    /// * **GGUF path**: dequantise into a dense tensor map via
658    ///   [`super::gguf_dense::dequantize_gguf_dense_tensors`], wrap the
659    ///   `HashMap` `SimpleBackend` impl with the LoRA wrapper, then build
660    ///   the model. The fused-QKV `attention.qkv` LoRA targets the split
661    ///   `to_q`/`to_k`/`to_v` candle keys (the same key-space the BF16
662    ///   path sees) — see [`super::lora`] for the splat math.
663    fn load_transformer(
664        &self,
665        device: &Device,
666        dtype: DType,
667        cfg: &Config,
668    ) -> Result<ZImageTransformer> {
669        let is_gguf = self.detect_is_gguf();
670        let xformer_paths = self.transformer_paths();
671        let has_lora = !self.pending_loras.is_empty();
672
673        if is_gguf {
674            if has_lora {
675                let adapters =
676                    super::lora::load_lora_adapters(&self.pending_loras, &self.base.progress)?;
677                let specs: Vec<super::lora::ZImageLoraSpec<'_>> = adapters
678                    .iter()
679                    .zip(self.pending_loras.iter())
680                    .map(|(adapter, w)| super::lora::ZImageLoraSpec {
681                        adapter: adapter.as_ref(),
682                        scale: w.scale,
683                        path_hash: super::lora::lora_path_hash(&w.path),
684                    })
685                    .collect();
686                let vb = super::lora::gguf_lora_var_builder(
687                    &self.base.paths.transformer,
688                    &specs,
689                    device,
690                    &self.base.progress,
691                )?;
692                return Ok(ZImageTransformer::Quantized(Box::new(
693                    super::quantized_transformer::QuantizedZImageTransformer2DModel::new(
694                        cfg, dtype, vb,
695                    )?,
696                )));
697            }
698            let qvb =
699                quantized_var_builder::VarBuilder::from_gguf(&self.base.paths.transformer, device)?;
700            Ok(ZImageTransformer::Dense(Box::new(
701                load_gguf_dense_transformer(cfg, dtype, qvb)?,
702            )))
703        } else if has_lora {
704            // Build an mmap-backed SimpleBackend so we can layer a LoRA
705            // wrapper on top. Drops the streaming progress bar that
706            // `load_safetensors_with_progress` would emit; the LoRA
707            // info-line is good enough signal for a single load.
708            use candle_core::safetensors::MmapedSafetensors;
709            let path_refs: Vec<&std::path::Path> =
710                xformer_paths.iter().map(|p| p.as_path()).collect();
711            let st = unsafe { MmapedSafetensors::multi(&path_refs)? };
712            let inner: Box<dyn candle_nn::var_builder::SimpleBackend> =
713                Box::new(ZImageSafetensorsBackend::new(st));
714            let adapters =
715                super::lora::load_lora_adapters(&self.pending_loras, &self.base.progress)?;
716            let specs: Vec<super::lora::ZImageLoraSpec<'_>> = adapters
717                .iter()
718                .zip(self.pending_loras.iter())
719                .map(|(adapter, w)| super::lora::ZImageLoraSpec {
720                    adapter: adapter.as_ref(),
721                    scale: w.scale,
722                    path_hash: super::lora::lora_path_hash(&w.path),
723                })
724                .collect();
725            let wrapped =
726                super::lora::wrap_backend_with_lora(inner, &specs, &self.base.progress, None)?;
727            let vb = candle_nn::VarBuilder::from_backend(wrapped, dtype, device.clone());
728            Ok(ZImageTransformer::Dense(Box::new(
729                MoldZImageTransformer2DModel::new(cfg, vb)?,
730            )))
731        } else if self.offload {
732            use candle_core::safetensors::MmapedSafetensors;
733            let path_refs: Vec<&std::path::Path> =
734                xformer_paths.iter().map(|p| p.as_path()).collect();
735            let bytes_total: u64 = xformer_paths
736                .iter()
737                .map(|p| std::fs::metadata(p).map(|m| m.len()).unwrap_or(0))
738                .sum();
739            self.base
740                .progress
741                .weight_load("Z-Image transformer (offload stems)", 0, bytes_total);
742            let gpu_st = unsafe { MmapedSafetensors::multi(&path_refs)? };
743            let cpu_st = unsafe { MmapedSafetensors::multi(&path_refs)? };
744            let gpu_vb = candle_nn::VarBuilder::from_backend(
745                Box::new(ZImageSafetensorsBackend::new(gpu_st)),
746                dtype,
747                device.clone(),
748            );
749            let cpu_vb = candle_nn::VarBuilder::from_backend(
750                Box::new(ZImageSafetensorsBackend::new(cpu_st)),
751                dtype,
752                Device::Cpu,
753            );
754            self.base.progress.weight_load(
755                "Z-Image transformer (offload stems)",
756                bytes_total,
757                bytes_total,
758            );
759            Ok(ZImageTransformer::Offloaded(Box::new(
760                super::offload::OffloadedZImageTransformer::new(cfg, gpu_vb, cpu_vb)?,
761            )))
762        } else {
763            use candle_core::safetensors::MmapedSafetensors;
764            let path_refs: Vec<&std::path::Path> =
765                xformer_paths.iter().map(|p| p.as_path()).collect();
766            let bytes_total = xformer_paths
767                .iter()
768                .map(|p| std::fs::metadata(p).map(|m| m.len()).unwrap_or(0))
769                .sum();
770            self.base
771                .progress
772                .weight_load("Z-Image transformer", 0, bytes_total);
773            let st = unsafe { MmapedSafetensors::multi(&path_refs)? };
774            let xformer_vb = candle_nn::VarBuilder::from_backend(
775                Box::new(ZImageSafetensorsBackend::new(st)),
776                dtype,
777                device.clone(),
778            );
779            self.base
780                .progress
781                .weight_load("Z-Image transformer", bytes_total, bytes_total);
782            Ok(ZImageTransformer::Dense(Box::new(
783                MoldZImageTransformer2DModel::new(cfg, xformer_vb)?,
784            )))
785        }
786    }
787
788    /// Load VAE from disk.
789    fn load_vae(&self, device: &Device, dtype: DType) -> Result<AutoEncoderKL> {
790        let cached_tensors = self.load_vae_cpu_tensors()?;
791        load_zimage_vae(
792            self.base.paths.vae.as_path(),
793            dtype,
794            device,
795            &self.base.progress,
796            cached_tensors,
797        )
798    }
799
800    fn load_vae_cpu_tensors(&self) -> Result<Option<Arc<HashMap<String, Tensor>>>> {
801        let Some(shared_pool) = &self.shared_pool else {
802            return Ok(None);
803        };
804        shared_pool
805            .lock()
806            .unwrap()
807            .load_safetensors_cpu_tensors(std::slice::from_ref(&self.base.paths.vae))
808    }
809
810    /// Load all model components (Eager mode).
811    ///
812    /// On error, `self.base.loaded` remains `None` — all components are assembled into
813    /// local variables and only stored in `self.base.loaded` on success, so partial loads
814    /// cannot leave the engine in an inconsistent state.
815    pub fn load(&mut self) -> Result<()> {
816        if self.base.loaded.is_some() {
817            return Ok(());
818        }
819
820        // Sequential mode defers loading to generate_sequential()
821        if self.base.load_strategy == LoadStrategy::Sequential {
822            return Ok(());
823        }
824
825        tracing::info!(model = %self.base.model_name, "loading Z-Image model components...");
826
827        let is_gguf = self.detect_is_gguf();
828        let text_tokenizer_path = self.validate_paths()?;
829
830        let transformer_ref = effective_device_ref(
831            self.pending_placement.as_ref(),
832            |adv| Some(adv.transformer),
833            false,
834        );
835        let device = crate::device::resolve_device(Some(transformer_ref), || {
836            crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)
837        })?;
838        let dtype = crate::engine::gpu_dtype(&device);
839        let transformer_cfg = Config::z_image_turbo();
840
841        // Load transformer
842        let xformer_label = if is_gguf {
843            "Loading Z-Image transformer (GPU, GGUF -> dense)".to_string()
844        } else {
845            let xformer_paths = self.transformer_paths();
846            format!(
847                "Loading Z-Image transformer ({} shards)",
848                xformer_paths.len()
849            )
850        };
851        self.base.progress.stage_start(&xformer_label);
852        let xformer_start = Instant::now();
853
854        let transformer = self.load_transformer(&device, dtype, &transformer_cfg)?;
855
856        self.base
857            .progress
858            .stage_done(&xformer_label, xformer_start.elapsed());
859        tracing::info!(quantized = is_gguf, "Z-Image transformer loaded");
860
861        // --- Decide where to place VAE and Qwen3 text encoder based on remaining VRAM ---
862        // Log the raw driver reading; budget the placement decisions
863        // against the reserve-adjusted value below.
864        let free_raw = free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
865        let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
866        let is_cuda = device.is_cuda();
867        let is_metal = device.is_metal();
868        if free_raw > 0 {
869            self.base.progress.info(&format!(
870                "Free VRAM after transformer: {}",
871                fmt_gb(free_raw)
872            ));
873            tracing::info!(
874                free_vram = free_raw,
875                free_vram_usable = free,
876                "free VRAM after loading transformer"
877            );
878        }
879
880        // Eager mode drops the transformer before decode, so placement only
881        // needs enough room for VAE weights at load time. Decode workspace is
882        // allocated later after the large transformer has been released.
883        let vae_on_gpu = should_use_gpu(is_cuda, is_metal, free, VAE_WEIGHT_LOAD_VRAM_THRESHOLD);
884        let vae_ref =
885            effective_device_ref(self.pending_placement.as_ref(), |adv| Some(adv.vae), false);
886        let vae_device = crate::device::resolve_device(Some(vae_ref), || {
887            Ok(if vae_on_gpu {
888                device.clone()
889            } else {
890                Device::Cpu
891            })
892        })?;
893        let vae_on_gpu = !vae_device.is_cpu();
894        // GPU branch honours `MOLD_VAE_DTYPE`; CPU is already F32 (the highest
895        // precision we'd ever want), so the env knob is a no-op there.
896        let vae_dtype = if vae_on_gpu {
897            crate::device::resolve_vae_dtype(dtype)
898        } else {
899            DType::F32
900        };
901        let vae_device_label = if vae_on_gpu { "GPU" } else { "CPU" };
902
903        if !vae_on_gpu && (is_cuda || is_metal) {
904            self.base.progress.info(&format!(
905                "VAE on CPU ({} free < {} threshold for VAE weight load)",
906                fmt_gb(free),
907                fmt_gb(VAE_WEIGHT_LOAD_VRAM_THRESHOLD),
908            ));
909        }
910
911        // Load VAE
912        let vae_label = format!("Loading VAE ({})", vae_device_label);
913        self.base.progress.stage_start(&vae_label);
914        let vae_start = Instant::now();
915        let vae = self.load_vae(&vae_device, vae_dtype)?;
916        self.base
917            .progress
918            .stage_done(&vae_label, vae_start.elapsed());
919        tracing::info!(device = vae_device_label, "Z-Image VAE loaded");
920
921        // --- Qwen3 text encoder: auto-select variant based on VRAM ---
922        self.base.progress.stage_start("Selecting Qwen3 encoder");
923        let qwen3_resolve_start = Instant::now();
924        let qwen3_preference = zimage_qwen3_preference(
925            self.qwen3_variant.as_deref(),
926            &self.base.paths.text_encoder_files,
927        );
928        let (resolved_paths, is_qwen3_gguf, te_on_gpu, _te_auto_device_label) = {
929            let bf16_paths = self.base.paths.text_encoder_files.clone();
930            let have_bf16 = !bf16_paths.is_empty() && bf16_paths.iter().all(|p| p.exists());
931            crate::encoders::variant_resolution::resolve_qwen3_variant(
932                &self.base.progress,
933                qwen3_preference,
934                &device,
935                free,
936                &bf16_paths,
937                have_bf16,
938                false,
939                crate::encoders::variant_resolution::Qwen3Size::B4,
940            )?
941        };
942        self.base
943            .progress
944            .stage_done("Selecting Qwen3 encoder", qwen3_resolve_start.elapsed());
945
946        let qwen3_ref = effective_device_ref(self.pending_placement.as_ref(), |adv| adv.qwen, true);
947        let auto_te_device = if te_on_gpu {
948            device.clone()
949        } else {
950            Device::Cpu
951        };
952        let te_device =
953            crate::device::resolve_device(Some(qwen3_ref), || Ok(auto_te_device.clone()))?;
954        let te_on_gpu = !te_device.is_cpu();
955        let te_device_label = if te_on_gpu { "GPU" } else { "CPU" };
956        let te_dtype = if te_on_gpu { dtype } else { DType::F32 };
957
958        // Load text encoder
959        let bf16_cfg = encoders::qwen3_bf16::Qwen3BF16Config::qwen3_4b();
960        let te_label = if is_qwen3_gguf {
961            format!("Loading Qwen3 text encoder (GGUF, {})", te_device_label)
962        } else {
963            format!(
964                "Loading Qwen3 text encoder ({} shards, {})",
965                resolved_paths.len(),
966                te_device_label,
967            )
968        };
969        self.base.progress.stage_start(&te_label);
970        let te_start = Instant::now();
971        let text_tokenizer = self.load_text_tokenizer(&text_tokenizer_path)?;
972
973        let text_encoder = if is_qwen3_gguf {
974            encoders::qwen3::Qwen3Encoder::load_gguf_with_tokenizer(
975                &resolved_paths[0],
976                &text_tokenizer_path,
977                Some(text_tokenizer),
978                &te_device,
979                &bf16_cfg,
980            )?
981        } else {
982            encoders::qwen3::Qwen3Encoder::load_bf16_with_tokenizer(
983                &resolved_paths,
984                &text_tokenizer_path,
985                Some(text_tokenizer),
986                &te_device,
987                te_dtype,
988                &bf16_cfg,
989                &self.base.progress,
990            )?
991        };
992
993        self.base.progress.stage_done(&te_label, te_start.elapsed());
994        tracing::info!(device = %te_device_label, quantized = is_qwen3_gguf, "Qwen3 text encoder loaded");
995
996        self.base.loaded = Some(LoadedZImage {
997            transformer: Some(transformer),
998            text_encoder,
999            vae,
1000            transformer_cfg,
1001            device,
1002            vae_device,
1003            dtype,
1004            vae_dtype,
1005            is_gguf,
1006            vae_path: self.base.paths.vae.clone(),
1007        });
1008
1009        tracing::info!(model = %self.base.model_name, "all Z-Image components loaded successfully");
1010        Ok(())
1011    }
1012
1013    /// Reload the transformer from disk (called when it was dropped to free VRAM for VAE decode).
1014    fn reload_transformer(&self, loaded: &mut LoadedZImage) -> Result<()> {
1015        let transformer =
1016            self.load_transformer(&loaded.device, loaded.dtype, &loaded.transformer_cfg)?;
1017        loaded.transformer = Some(transformer);
1018        Ok(())
1019    }
1020
1021    fn uses_sequential_generate_path(&self) -> bool {
1022        self.base.load_strategy == LoadStrategy::Sequential
1023            || self.offload
1024            || !self.pending_loras.is_empty()
1025    }
1026
1027    /// Generate an image using sequential loading strategy.
1028    ///
1029    /// Loads components one at a time and drops them when done:
1030    /// 1. Load Qwen3 → encode → drop Qwen3
1031    /// 2. Load transformer → denoise → drop transformer
1032    /// 3. Load VAE → decode → drop VAE
1033    ///
1034    /// Peak memory: max(Qwen3_size, transformer_size) instead of sum(all).
1035    fn generate_sequential(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1036        let text_tokenizer_path = self.validate_paths()?;
1037        let is_gguf = self.detect_is_gguf();
1038        let transformer_cfg = Config::z_image_turbo();
1039
1040        match zimage_offload_decision(self.offload, is_gguf, !self.pending_loras.is_empty()) {
1041            ZImageOffloadDecision::Disabled => {}
1042            ZImageOffloadDecision::Unsupported(reason) => bail!("{reason}"),
1043            ZImageOffloadDecision::Selected => {}
1044        }
1045
1046        // Check memory budget
1047        if let Some(warning) = check_memory_budget(&self.base.paths, LoadStrategy::Sequential) {
1048            self.base.progress.info(&warning);
1049        }
1050
1051        let transformer_ref = effective_device_ref(
1052            self.pending_placement.as_ref(),
1053            |adv| Some(adv.transformer),
1054            false,
1055        );
1056        let device = crate::device::resolve_device(Some(transformer_ref), || {
1057            crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)
1058        })?;
1059        let dtype = crate::engine::gpu_dtype(&device);
1060
1061        let start = Instant::now();
1062        let seed = req.seed.unwrap_or_else(rand_seed);
1063
1064        let width = req.width as usize;
1065        let height = req.height as usize;
1066
1067        tracing::info!(
1068            prompt = %req.prompt,
1069            seed, width, height,
1070            steps = req.steps,
1071            "starting sequential Z-Image generation"
1072        );
1073
1074        self.base
1075            .progress
1076            .info("Using sequential loading (load-use-drop) to minimize peak memory");
1077
1078        // --- Phase 1: Qwen3 text encoding (check cache first to skip encoder load) ---
1079        let cache_key = prompt_text_key(&req.prompt);
1080        let (cap_feats, cap_mask) = if let Some(cap_feats) =
1081            restore_cached_tensor(&self.prompt_cache, &cache_key, &device, dtype)?
1082        {
1083            self.base.progress.cache_hit("prompt conditioning");
1084            let token_count = cap_feats.dim(1)?;
1085            let cap_mask = Tensor::ones((1, token_count), DType::U8, &device)?;
1086            (cap_feats, cap_mask)
1087        } else {
1088            // Reserve-adjusted reading drives the Qwen3 variant selection.
1089            let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
1090            self.base.progress.stage_start("Selecting Qwen3 encoder");
1091            let qwen3_resolve_start = Instant::now();
1092            let qwen3_preference = zimage_qwen3_preference(
1093                self.qwen3_variant.as_deref(),
1094                &self.base.paths.text_encoder_files,
1095            );
1096            let (resolved_paths, is_qwen3_gguf, te_on_gpu, _te_auto_device_label) = {
1097                let bf16_paths = self.base.paths.text_encoder_files.clone();
1098                let have_bf16 = !bf16_paths.is_empty() && bf16_paths.iter().all(|p| p.exists());
1099                crate::encoders::variant_resolution::resolve_qwen3_variant(
1100                    &self.base.progress,
1101                    qwen3_preference,
1102                    &device,
1103                    free,
1104                    &bf16_paths,
1105                    have_bf16,
1106                    false,
1107                    crate::encoders::variant_resolution::Qwen3Size::B4,
1108                )?
1109            };
1110            self.base
1111                .progress
1112                .stage_done("Selecting Qwen3 encoder", qwen3_resolve_start.elapsed());
1113
1114            let qwen3_ref =
1115                effective_device_ref(self.pending_placement.as_ref(), |adv| adv.qwen, true);
1116            let auto_te_device = if te_on_gpu {
1117                device.clone()
1118            } else {
1119                Device::Cpu
1120            };
1121            let te_device =
1122                crate::device::resolve_device(Some(qwen3_ref), || Ok(auto_te_device.clone()))?;
1123            let te_on_gpu = !te_device.is_cpu();
1124            let te_device_label = if te_on_gpu { "GPU" } else { "CPU" };
1125            let te_dtype = if te_on_gpu { dtype } else { DType::F32 };
1126
1127            let bf16_cfg = encoders::qwen3_bf16::Qwen3BF16Config::qwen3_4b();
1128            let te_label = if is_qwen3_gguf {
1129                format!("Loading Qwen3 text encoder (GGUF, {})", te_device_label)
1130            } else {
1131                format!(
1132                    "Loading Qwen3 text encoder ({} shards, {})",
1133                    resolved_paths.len(),
1134                    te_device_label,
1135                )
1136            };
1137            let te_size: u64 = resolved_paths
1138                .iter()
1139                .filter_map(|p| std::fs::metadata(p).ok())
1140                .map(|m| m.len())
1141                .sum();
1142            let te_activation_budget = crate::device::activation_bytes(
1143                req.width,
1144                req.height,
1145                1,
1146                crate::device::dtype_bytes(te_dtype),
1147                crate::device::ActivationFamily::SmallTransformer,
1148            );
1149            preflight_memory_check("Qwen3 text encoder", te_size, te_activation_budget)?;
1150
1151            if let Some(status) = memory_status_string() {
1152                self.base.progress.info(&status);
1153            }
1154
1155            self.base.progress.stage_start(&te_label);
1156            let te_start = Instant::now();
1157            let text_tokenizer = self.load_text_tokenizer(&text_tokenizer_path)?;
1158
1159            let mut text_encoder = if is_qwen3_gguf {
1160                encoders::qwen3::Qwen3Encoder::load_gguf_with_tokenizer(
1161                    &resolved_paths[0],
1162                    &text_tokenizer_path,
1163                    Some(text_tokenizer),
1164                    &te_device,
1165                    &bf16_cfg,
1166                )?
1167            } else {
1168                encoders::qwen3::Qwen3Encoder::load_bf16_with_tokenizer(
1169                    &resolved_paths,
1170                    &text_tokenizer_path,
1171                    Some(text_tokenizer),
1172                    &te_device,
1173                    te_dtype,
1174                    &bf16_cfg,
1175                    &self.base.progress,
1176                )?
1177            };
1178            self.base.progress.stage_done(&te_label, te_start.elapsed());
1179
1180            let (cap_feats, cap_mask) = Self::encode_prompt_cached(
1181                &self.base.progress,
1182                &self.prompt_cache,
1183                &mut text_encoder,
1184                &req.prompt,
1185                &device,
1186                dtype,
1187            )?;
1188
1189            drop(text_encoder);
1190            self.base.progress.info("Freed Qwen3 text encoder");
1191            tracing::info!("Qwen3 text encoder dropped (sequential mode)");
1192
1193            (cap_feats, cap_mask)
1194        };
1195
1196        // Calculate latent dimensions up front so img2img can encode the source image
1197        // before the transformer is loaded. This keeps the encode path on GPU and
1198        // avoids the multi-minute CPU fallback.
1199        let vae_align = 16;
1200        let latent_h = 2 * (height / vae_align);
1201        let latent_w = 2 * (width / vae_align);
1202
1203        let patch_size = transformer_cfg.all_patch_size[0];
1204        let image_seq_len = (latent_h / patch_size) * (latent_w / patch_size);
1205        let (mut scheduler, start_index) = build_zimage_scheduler(
1206            req.steps as usize,
1207            image_seq_len,
1208            req.source_image.as_ref().map(|_| req.strength),
1209        );
1210
1211        if req.source_image.is_some() {
1212            tracing::info!(
1213                strength = req.strength,
1214                start_index,
1215                start_sigma = scheduler.sigmas[0],
1216                remaining_sigmas = scheduler.sigmas.len(),
1217                remaining_steps = scheduler.sigmas.len().saturating_sub(1),
1218                "img2img: truncated schedule from strength"
1219            );
1220        }
1221
1222        // --- Phase 2: Build initial latents ---
1223        let (mut latents, inpaint_ctx) = if let Some(ref source_bytes) = req.source_image {
1224            let start_sigma = scheduler.sigmas[0];
1225
1226            // Encode before loading the transformer so we can keep the VAE on GPU.
1227            let encode_vae_device = if device.is_cuda() || device.is_metal() {
1228                device.clone()
1229            } else {
1230                Device::Cpu
1231            };
1232            let encode_vae_dtype = if encode_vae_device.is_cpu() {
1233                DType::F32
1234            } else {
1235                crate::device::resolve_vae_dtype(dtype)
1236            };
1237            let encode_label = if encode_vae_device.is_cpu() {
1238                "Loading VAE for source encoding (CPU)"
1239            } else {
1240                "Loading VAE for source encoding (GPU)"
1241            };
1242
1243            self.base.progress.stage_start(encode_label);
1244            let vae_enc_start = Instant::now();
1245            let encode_vae = self.load_vae(&encode_vae_device, encode_vae_dtype)?;
1246            self.base
1247                .progress
1248                .stage_done(encode_label, vae_enc_start.elapsed());
1249
1250            self.base
1251                .progress
1252                .stage_start("Encoding source image (VAE)");
1253            let encode_start = Instant::now();
1254            let source_tensor = img_utils::decode_source_image(
1255                source_bytes,
1256                req.width,
1257                req.height,
1258                img_utils::NormalizeRange::ZeroToOne,
1259                &encode_vae_device,
1260                encode_vae_dtype,
1261            )?;
1262            let encoded = encode_vae.encode(&source_tensor)?;
1263            self.base
1264                .progress
1265                .stage_done("Encoding source image (VAE)", encode_start.elapsed());
1266
1267            // Drop encoding VAE before loading transformer
1268            drop(encode_vae);
1269
1270            // Generate noise on the target device
1271            let encoded = encoded.to_dtype(dtype)?.to_device(&device)?;
1272            let prepared = crate::img2img::prepare_flow_match_img2img(
1273                &encoded,
1274                seed,
1275                &[1, 16, latent_h, latent_w],
1276                start_sigma,
1277                req.mask_image.as_deref(),
1278                latent_h,
1279                latent_w,
1280                &device,
1281                dtype,
1282            )?;
1283            // Add frame dimension: (B, C, H, W) -> (B, C, 1, H, W)
1284            (prepared.initial_latents.unsqueeze(2)?, prepared.inpaint_ctx)
1285        } else {
1286            // txt2img: pure noise
1287            let noise =
1288                crate::engine::seeded_randn(seed, &[1, 16, latent_h, latent_w], &device, dtype)?;
1289            (noise.unsqueeze(2)?, None)
1290        };
1291
1292        // --- Phase 3: Load transformer and denoise ---
1293        let xformer_paths = self.transformer_paths();
1294        let xformer_size: u64 = xformer_paths
1295            .iter()
1296            .filter_map(|p| std::fs::metadata(p).ok())
1297            .map(|m| m.len())
1298            .sum();
1299        let xformer_activation_budget = crate::device::activation_bytes(
1300            req.width,
1301            req.height,
1302            1,
1303            crate::device::dtype_bytes(dtype),
1304            crate::device::ActivationFamily::ZImageDit,
1305        );
1306        preflight_memory_check(
1307            "Z-Image transformer",
1308            xformer_size,
1309            xformer_activation_budget,
1310        )?;
1311
1312        if let Some(status) = memory_status_string() {
1313            self.base.progress.info(&status);
1314        }
1315
1316        let xformer_label = if is_gguf {
1317            "Loading Z-Image transformer (GPU, GGUF -> dense)".to_string()
1318        } else {
1319            format!(
1320                "Loading Z-Image transformer ({} shards)",
1321                xformer_paths.len()
1322            )
1323        };
1324        self.base.progress.stage_start(&xformer_label);
1325        let xformer_start = Instant::now();
1326        let transformer = self.load_transformer(&device, dtype, &transformer_cfg)?;
1327        self.base
1328            .progress
1329            .stage_done(&xformer_label, xformer_start.elapsed());
1330
1331        let num_steps = scheduler.sigmas.len().saturating_sub(1);
1332        let denoise_label = format!("Denoising ({} steps)", num_steps);
1333        self.base.progress.stage_start(&denoise_label);
1334        let denoise_start = Instant::now();
1335
1336        for step in 0..num_steps {
1337            let step_start = Instant::now();
1338            let t = model_timestep(&scheduler);
1339            let t_tensor = Tensor::from_vec(vec![t as f32], (1,), &device)?.to_dtype(dtype)?;
1340            if zimage_debug_enabled() {
1341                tracing::debug!(
1342                    step = step + 1,
1343                    total = num_steps,
1344                    sigma = scheduler.current_sigma(),
1345                    timestep = t,
1346                    "{}",
1347                    tensor_stats_summary("latents_in", &latents)?
1348                );
1349            }
1350            let noise_pred = transformer.forward(&latents, &t_tensor, &cap_feats, &cap_mask)?;
1351            if zimage_debug_enabled() {
1352                tracing::debug!(
1353                    step = step + 1,
1354                    total = num_steps,
1355                    "{}",
1356                    tensor_stats_summary("noise_pred_raw", &noise_pred)?
1357                );
1358            }
1359            let noise_pred = noise_pred.neg()?;
1360            let noise_pred_4d = noise_pred.squeeze(2)?;
1361            let latents_4d = latents.squeeze(2)?;
1362            let prev_latents = scheduler.step(&noise_pred_4d, &latents_4d)?;
1363            latents = prev_latents.unsqueeze(2)?;
1364            if zimage_debug_enabled() {
1365                tracing::debug!(
1366                    step = step + 1,
1367                    total = num_steps,
1368                    sigma_next = scheduler.current_sigma(),
1369                    "{}",
1370                    tensor_stats_summary("latents_out", &latents)?
1371                );
1372            }
1373
1374            // Inpainting: blend preserved regions back at current noise level
1375            if let Some(ref ctx) = inpaint_ctx {
1376                let latents_4d = latents.squeeze(2)?;
1377                let blended = crate::img2img::apply_flow_match_inpaint(
1378                    &latents_4d,
1379                    ctx,
1380                    scheduler.sigmas[step + 1],
1381                )?;
1382                latents = blended.unsqueeze(2)?;
1383            }
1384
1385            self.base.progress.emit(ProgressEvent::DenoiseStep {
1386                step: step + 1,
1387                total: num_steps,
1388                elapsed: step_start.elapsed(),
1389            });
1390        }
1391
1392        self.base
1393            .progress
1394            .stage_done(&denoise_label, denoise_start.elapsed());
1395
1396        // Drop transformer and text embeddings to free memory for VAE decode
1397        drop(transformer);
1398        self.base.progress.info("Freed Z-Image transformer");
1399        drop(cap_feats);
1400        drop(cap_mask);
1401        device.synchronize()?;
1402        tracing::info!("Transformer dropped (sequential mode)");
1403
1404        // --- Phase 3: Load VAE and decode ---
1405        if let Some(status) = memory_status_string() {
1406            self.base.progress.info(&status);
1407        }
1408        // With sequential loading, we can always try GPU for VAE since transformer is freed.
1409        // Reserve-adjusted reading: should_use_gpu must respect the OS reserve.
1410        let free_for_vae = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
1411        let vae_on_gpu = should_use_gpu(
1412            device.is_cuda(),
1413            device.is_metal(),
1414            free_for_vae,
1415            VAE_DECODE_VRAM_THRESHOLD,
1416        );
1417        let vae_ref =
1418            effective_device_ref(self.pending_placement.as_ref(), |adv| Some(adv.vae), false);
1419        let vae_device = crate::device::resolve_device(Some(vae_ref), || {
1420            Ok(if vae_on_gpu {
1421                device.clone()
1422            } else {
1423                Device::Cpu
1424            })
1425        })?;
1426        let vae_on_gpu = !vae_device.is_cpu();
1427        // GPU branch honours `MOLD_VAE_DTYPE`; CPU is already F32 (the highest
1428        // precision we'd ever want), so the env knob is a no-op there.
1429        let vae_dtype = if vae_on_gpu {
1430            crate::device::resolve_vae_dtype(dtype)
1431        } else {
1432            DType::F32
1433        };
1434        let vae_device_label = if vae_on_gpu { "GPU" } else { "CPU" };
1435
1436        let vae_label = format!("Loading VAE ({})", vae_device_label);
1437        self.base.progress.stage_start(&vae_label);
1438        let vae_start = Instant::now();
1439        let vae = self.load_vae(&vae_device, vae_dtype)?;
1440        self.base
1441            .progress
1442            .stage_done(&vae_label, vae_start.elapsed());
1443
1444        self.base.progress.stage_start("VAE decode");
1445        let vae_decode_start = Instant::now();
1446
1447        let latents = latents
1448            .squeeze(2)?
1449            .to_device(&vae_device)?
1450            .to_dtype(vae_dtype)?;
1451        let image = vae.decode(&latents)?;
1452        let image = postprocess_image(&image)?;
1453        let image = image.i(0)?;
1454
1455        self.base
1456            .progress
1457            .stage_done("VAE decode", vae_decode_start.elapsed());
1458
1459        // VAE dropped here
1460        let output_metadata = build_output_metadata(req, seed, None);
1461        let image_bytes = encode_image(
1462            &image,
1463            req.resolved_output_format(),
1464            req.width,
1465            req.height,
1466            output_metadata.as_ref(),
1467        )?;
1468
1469        let generation_time_ms = start.elapsed().as_millis() as u64;
1470        tracing::info!(
1471            generation_time_ms,
1472            seed,
1473            "sequential Z-Image generation complete"
1474        );
1475
1476        Ok(GenerateResponse {
1477            images: vec![ImageData {
1478                data: image_bytes,
1479                format: req.resolved_output_format(),
1480                width: req.width,
1481                height: req.height,
1482                index: 0,
1483            }],
1484            generation_time_ms,
1485            model: req.model.clone(),
1486            seed_used: seed,
1487            video: None,
1488            gpu: None,
1489        })
1490    }
1491}
1492
1493impl ZImageEngine {
1494    fn generate_inner(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1495        if req.scheduler.is_some() {
1496            tracing::warn!(
1497                "scheduler selection not supported for Z-Image (flow-matching), ignoring"
1498            );
1499        }
1500        // Sequential mode: load-use-drop each component. LoRA requests also
1501        // use this path because LoRA-merged transformer construction has
1502        // higher transient memory and must not run while eager-mode VAE/text
1503        // encoders are still GPU-resident.
1504        if self.uses_sequential_generate_path() {
1505            self.base.unload();
1506            return self.generate_sequential(req);
1507        }
1508
1509        // Eager mode: use pre-loaded components
1510        if self.base.loaded.is_none() {
1511            self.load()?;
1512        }
1513        if self.base.loaded.is_none() {
1514            bail!("model not loaded — call load() first");
1515        }
1516
1517        // Borrow progress reporter separately from loaded state.
1518        let progress = &self.base.progress;
1519
1520        let start = Instant::now();
1521
1522        // Reload transformer if it was dropped (offloaded) after previous VAE decode
1523        let loaded_ref = self
1524            .base
1525            .loaded
1526            .as_ref()
1527            .ok_or_else(|| anyhow::anyhow!("model not loaded — call load() first"))?;
1528        let needs_reload = loaded_ref.transformer.is_none();
1529        if needs_reload {
1530            {
1531                let mut loaded_mut = self
1532                    .base
1533                    .loaded
1534                    .take()
1535                    .ok_or_else(|| anyhow::anyhow!("model not loaded — call load() first"))?;
1536                let xformer_label = if loaded_mut.is_gguf {
1537                    "Reloading Z-Image transformer (GPU, GGUF -> dense)"
1538                } else {
1539                    "Reloading Z-Image transformer (GPU, BF16)"
1540                };
1541                progress.stage_start(xformer_label);
1542                let reload_start = Instant::now();
1543                self.reload_transformer(&mut loaded_mut)?;
1544                progress.stage_done(xformer_label, reload_start.elapsed());
1545                self.base.loaded = Some(loaded_mut);
1546            }
1547        }
1548
1549        let loaded = self
1550            .base
1551            .loaded
1552            .as_mut()
1553            .ok_or_else(|| anyhow::anyhow!("model not loaded — call load() first"))?;
1554        let seed = req.seed.unwrap_or_else(rand_seed);
1555
1556        let width = req.width as usize;
1557        let height = req.height as usize;
1558
1559        tracing::info!(
1560            prompt = %req.prompt,
1561            seed, width, height,
1562            steps = req.steps,
1563            "starting Z-Image generation"
1564        );
1565
1566        // 1. Encode prompt with Qwen3 (check cache first to avoid unnecessary reload)
1567        let cache_key = prompt_text_key(&req.prompt);
1568        let (cap_feats, cap_mask) = if let Some(cap_feats) =
1569            restore_cached_tensor(&self.prompt_cache, &cache_key, &loaded.device, loaded.dtype)?
1570        {
1571            progress.cache_hit("prompt conditioning");
1572            let token_count = cap_feats.dim(1)?;
1573            let cap_mask = Tensor::ones((1, token_count), DType::U8, &loaded.device)?;
1574            (cap_feats, cap_mask)
1575        } else {
1576            // Cache miss — restore encoder if it was dropped or parked after
1577            // a previous generation. is_parked() is true only on the BF16
1578            // path; GGUF flows through the reload() branch.
1579            if loaded.text_encoder.model.is_none() {
1580                let te_label = if loaded.text_encoder.is_parked() {
1581                    "Unparking Qwen3 encoder (CPU→GPU)"
1582                } else if loaded.text_encoder.is_quantized {
1583                    "Reloading Qwen3 encoder (GGUF)"
1584                } else {
1585                    "Reloading Qwen3 encoder (BF16)"
1586                };
1587                progress.stage_start(te_label);
1588                let reload_start = Instant::now();
1589                if loaded.text_encoder.is_parked() {
1590                    loaded.text_encoder.unpark_to_gpu(progress)?;
1591                } else {
1592                    loaded.text_encoder.reload(progress)?;
1593                }
1594                progress.stage_done(te_label, reload_start.elapsed());
1595            }
1596
1597            let (cap_feats, cap_mask) = Self::encode_prompt_cached(
1598                progress,
1599                &self.prompt_cache,
1600                &mut loaded.text_encoder,
1601                &req.prompt,
1602                &loaded.device,
1603                loaded.dtype,
1604            )?;
1605            tracing::info!(token_count = cap_feats.dim(1)?, "text encoding complete");
1606
1607            // Free GPU VRAM for denoising + VAE decode. With
1608            // `MOLD_KEEP_TE_RAM=1` and the BF16 encoder, parameters move
1609            // to host RAM instead of being released — saves ~10 s of reload
1610            // on the next request. GGUF and Metal flow through the original
1611            // drop path (Metal is unified memory, GGUF is device-tied).
1612            if loaded.text_encoder.on_gpu || loaded.device.is_metal() {
1613                let park_mode = crate::device::keep_te_in_ram()
1614                    && !loaded.device.is_metal()
1615                    && !loaded.text_encoder.is_quantized;
1616                if park_mode {
1617                    loaded.text_encoder.park_to_cpu()?;
1618                    tracing::info!(
1619                        on_gpu = loaded.text_encoder.on_gpu,
1620                        "Qwen3 text encoder parked to CPU host RAM"
1621                    );
1622                } else {
1623                    loaded.text_encoder.drop_weights();
1624                    tracing::info!(
1625                        on_gpu = loaded.text_encoder.on_gpu,
1626                        "Qwen3 text encoder dropped to free memory for denoising"
1627                    );
1628                }
1629            }
1630
1631            (cap_feats, cap_mask)
1632        };
1633
1634        // 3. Calculate latent dimensions: 2 * (image_size / 16)
1635        let vae_align = 16;
1636        let latent_h = 2 * (height / vae_align);
1637        let latent_w = 2 * (width / vae_align);
1638
1639        // 5. Initialize scheduler
1640        let patch_size = loaded.transformer_cfg.all_patch_size[0];
1641        let image_seq_len = (latent_h / patch_size) * (latent_w / patch_size);
1642        let (mut scheduler, start_index) = build_zimage_scheduler(
1643            req.steps as usize,
1644            image_seq_len,
1645            req.source_image.as_ref().map(|_| req.strength),
1646        );
1647
1648        if req.source_image.is_some() {
1649            tracing::info!(
1650                strength = req.strength,
1651                start_index,
1652                start_sigma = scheduler.sigmas[0],
1653                remaining_sigmas = scheduler.sigmas.len(),
1654                remaining_steps = scheduler.sigmas.len().saturating_sub(1),
1655                "img2img: truncated schedule from strength"
1656            );
1657        }
1658
1659        // 6. Build initial latents — img2img encodes source image, txt2img uses pure noise
1660        let (mut latents, inpaint_ctx) = if let Some(ref source_bytes) = req.source_image {
1661            let start_sigma = scheduler.sigmas[0];
1662
1663            // Encode source image through the pre-loaded VAE
1664            progress.stage_start("Encoding source image (VAE)");
1665            let encode_start = Instant::now();
1666            let vae_encode_device = &loaded.vae_device;
1667            let vae_encode_dtype = if loaded.vae_device.is_cpu() {
1668                DType::F32
1669            } else {
1670                loaded.dtype
1671            };
1672            let source_tensor = img_utils::decode_source_image(
1673                source_bytes,
1674                req.width,
1675                req.height,
1676                img_utils::NormalizeRange::ZeroToOne,
1677                vae_encode_device,
1678                vae_encode_dtype,
1679            )?;
1680            let encoded = loaded.vae.encode(&source_tensor)?;
1681            progress.stage_done("Encoding source image (VAE)", encode_start.elapsed());
1682
1683            let encoded = encoded.to_dtype(loaded.dtype)?.to_device(&loaded.device)?;
1684
1685            let prepared = crate::img2img::prepare_flow_match_img2img(
1686                &encoded,
1687                seed,
1688                &[1, 16, latent_h, latent_w],
1689                start_sigma,
1690                req.mask_image.as_deref(),
1691                latent_h,
1692                latent_w,
1693                &loaded.device,
1694                loaded.dtype,
1695            )?;
1696            (prepared.initial_latents.unsqueeze(2)?, prepared.inpaint_ctx)
1697        } else {
1698            // txt2img: pure noise (B, 16, latent_h, latent_w) → add frame dim
1699            let noise = crate::engine::seeded_randn(
1700                seed,
1701                &[1, 16, latent_h, latent_w],
1702                &loaded.device,
1703                loaded.dtype,
1704            )?;
1705            (noise.unsqueeze(2)?, None)
1706        };
1707
1708        // 7. Denoising loop
1709        let num_steps = scheduler.sigmas.len().saturating_sub(1);
1710        let denoise_label = format!("Denoising ({} steps)", num_steps);
1711        progress.stage_start(&denoise_label);
1712        let denoise_start = Instant::now();
1713
1714        // Scope the transformer borrow so it can be dropped before VAE decode
1715        {
1716            let transformer = loaded
1717                .transformer
1718                .as_ref()
1719                .expect("transformer must be loaded for denoising");
1720
1721            for step in 0..num_steps {
1722                let step_start = Instant::now();
1723                let t = model_timestep(&scheduler);
1724                let t_tensor = Tensor::from_vec(vec![t as f32], (1,), &loaded.device)?
1725                    .to_dtype(loaded.dtype)?;
1726                if zimage_debug_enabled() {
1727                    tracing::debug!(
1728                        step = step + 1,
1729                        total = num_steps,
1730                        sigma = scheduler.current_sigma(),
1731                        timestep = t,
1732                        "{}",
1733                        tensor_stats_summary("latents_in", &latents)?
1734                    );
1735                }
1736
1737                // Forward pass through transformer
1738                let noise_pred = transformer.forward(&latents, &t_tensor, &cap_feats, &cap_mask)?;
1739                if zimage_debug_enabled() {
1740                    tracing::debug!(
1741                        step = step + 1,
1742                        total = num_steps,
1743                        "{}",
1744                        tensor_stats_summary("noise_pred_raw", &noise_pred)?
1745                    );
1746                }
1747
1748                // Negate prediction (Z-Image specific)
1749                let noise_pred = noise_pred.neg()?;
1750
1751                // Remove frame dimension for scheduler: (B, C, 1, H, W) → (B, C, H, W)
1752                let noise_pred_4d = noise_pred.squeeze(2)?;
1753                let latents_4d = latents.squeeze(2)?;
1754
1755                // Scheduler step
1756                let prev_latents = scheduler.step(&noise_pred_4d, &latents_4d)?;
1757
1758                // Add back frame dimension
1759                latents = prev_latents.unsqueeze(2)?;
1760                if zimage_debug_enabled() {
1761                    tracing::debug!(
1762                        step = step + 1,
1763                        total = num_steps,
1764                        sigma_next = scheduler.current_sigma(),
1765                        "{}",
1766                        tensor_stats_summary("latents_out", &latents)?
1767                    );
1768                }
1769
1770                // Inpainting: blend preserved regions back at current noise level
1771                if let Some(ref ctx) = inpaint_ctx {
1772                    let latents_4d = latents.squeeze(2)?;
1773                    let blended = crate::img2img::apply_flow_match_inpaint(
1774                        &latents_4d,
1775                        ctx,
1776                        scheduler.sigmas[step + 1],
1777                    )?;
1778                    latents = blended.unsqueeze(2)?;
1779                }
1780
1781                progress.emit(ProgressEvent::DenoiseStep {
1782                    step: step + 1,
1783                    total: num_steps,
1784                    elapsed: step_start.elapsed(),
1785                });
1786            }
1787        }
1788
1789        progress.stage_done(&denoise_label, denoise_start.elapsed());
1790        tracing::info!("denoising complete");
1791
1792        // Free text embeddings — no longer needed after denoising
1793        drop(cap_feats);
1794        drop(cap_mask);
1795
1796        // Drop the transformer weights from GPU to free VRAM for VAE decode.
1797        // The transformer (~6.6GB for Q8) is only needed during denoising.
1798        // It will be reloaded from disk on the next generate() call.
1799        loaded.transformer = None;
1800        // Synchronize to ensure CUDA's caching allocator reclaims the freed memory
1801        // before VAE decode allocates large im2col workspace buffers (~6GB at 1024x1024).
1802        loaded.device.synchronize()?;
1803        tracing::info!("Z-Image transformer dropped from GPU to free VRAM for VAE decode");
1804
1805        // 8. VAE decode — try GPU first, fall back to CPU on OOM
1806        progress.stage_start("VAE decode");
1807        let vae_start = Instant::now();
1808
1809        // Remove frame dimension: (B, C, 1, H, W) → (B, C, H, W)
1810        let latents_4d = latents.squeeze(2)?;
1811
1812        // Try VAE decode on the pre-assigned device
1813        let image = {
1814            let decode_latents = latents_4d.to_device(&loaded.vae_device)?.to_dtype(
1815                if loaded.vae_device.is_cpu() {
1816                    DType::F32
1817                } else {
1818                    loaded.vae_dtype
1819                },
1820            )?;
1821            match loaded.vae.decode(&decode_latents) {
1822                Ok(img) => img,
1823                Err(e) if loaded.vae_device.is_cuda() => {
1824                    // OOM on GPU — reload VAE on CPU and retry
1825                    let err_msg = format!("{e}");
1826                    if err_msg.contains("OUT_OF_MEMORY") || err_msg.contains("out of memory") {
1827                        tracing::warn!("VAE decode OOM on GPU, falling back to CPU");
1828                        progress.info("VAE decode OOM on GPU — retrying on CPU");
1829                        loaded.device.synchronize()?;
1830                        // Load a fresh VAE on CPU (can't call self.load_vae_cpu() due to borrow)
1831                        let cpu_vae = load_zimage_vae(
1832                            loaded.vae_path.as_path(),
1833                            DType::F32,
1834                            &Device::Cpu,
1835                            progress,
1836                            None,
1837                        )?;
1838                        let cpu_latents =
1839                            latents_4d.to_device(&Device::Cpu)?.to_dtype(DType::F32)?;
1840                        cpu_vae.decode(&cpu_latents)?
1841                    } else {
1842                        return Err(e.into());
1843                    }
1844                }
1845                Err(e) => return Err(e.into()),
1846            }
1847        };
1848
1849        // Post-process: [-1, 1] → [0, 255] (candle z_image utility)
1850        let image = postprocess_image(&image)?;
1851        let image = image.i(0)?; // Remove batch dimension → [3, H, W]
1852
1853        progress.stage_done("VAE decode", vae_start.elapsed());
1854
1855        // 9. Encode to output format
1856        let output_metadata = build_output_metadata(req, seed, None);
1857        let image_bytes = encode_image(
1858            &image,
1859            req.resolved_output_format(),
1860            req.width,
1861            req.height,
1862            output_metadata.as_ref(),
1863        )?;
1864
1865        let generation_time_ms = start.elapsed().as_millis() as u64;
1866        tracing::info!(generation_time_ms, seed, "Z-Image generation complete");
1867
1868        Ok(GenerateResponse {
1869            images: vec![ImageData {
1870                data: image_bytes,
1871                format: req.resolved_output_format(),
1872                width: req.width,
1873                height: req.height,
1874                index: 0,
1875            }],
1876            generation_time_ms,
1877            model: req.model.clone(),
1878            seed_used: seed,
1879            video: None,
1880            gpu: None,
1881        })
1882    }
1883}
1884
1885impl InferenceEngine for ZImageEngine {
1886    fn generate(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1887        self.pending_placement = req.placement.clone();
1888        self.pending_loras = effective_zimage_loras(req);
1889        let result = self.generate_inner(req);
1890        self.pending_placement = None;
1891        self.pending_loras.clear();
1892        result
1893    }
1894
1895    fn model_name(&self) -> &str {
1896        self.base.model_name()
1897    }
1898
1899    fn is_loaded(&self) -> bool {
1900        // Sequential mode is always "ready" — it loads on demand
1901        self.base.is_loaded()
1902    }
1903
1904    fn load(&mut self) -> Result<()> {
1905        ZImageEngine::load(self)
1906    }
1907
1908    fn unload(&mut self) {
1909        self.base.unload();
1910        clear_cache(&self.prompt_cache);
1911    }
1912
1913    fn set_on_progress(&mut self, callback: ProgressCallback) {
1914        self.base.set_on_progress(callback);
1915    }
1916
1917    fn clear_on_progress(&mut self) {
1918        self.base.clear_on_progress();
1919    }
1920
1921    fn model_paths(&self) -> Option<&mold_core::ModelPaths> {
1922        Some(&self.base.paths)
1923    }
1924}
1925
1926#[cfg(test)]
1927mod tests {
1928    use super::*;
1929    use crate::device::should_use_gpu;
1930    use crate::engine::LoadStrategy;
1931    use crate::shared_pool::SharedPool;
1932    use mold_core::ModelPaths;
1933    use std::fs;
1934    use std::path::{Path, PathBuf};
1935    use std::sync::{Arc, Mutex};
1936    use std::time::{SystemTime, UNIX_EPOCH};
1937    use tokenizers::models::bpe::BPE;
1938
1939    fn temp_test_dir(prefix: &str) -> PathBuf {
1940        let suffix = SystemTime::now()
1941            .duration_since(UNIX_EPOCH)
1942            .unwrap()
1943            .as_nanos();
1944        let dir = std::env::temp_dir().join(format!("{prefix}-{}-{suffix}", std::process::id()));
1945        fs::create_dir_all(&dir).unwrap();
1946        dir
1947    }
1948
1949    fn touch(dir: &Path, name: &str) -> PathBuf {
1950        let path = dir.join(name);
1951        fs::write(&path, b"test").unwrap();
1952        path
1953    }
1954
1955    fn zimage_model_paths(
1956        transformer: PathBuf,
1957        transformer_shards: Vec<PathBuf>,
1958        vae: PathBuf,
1959        text_tokenizer: Option<PathBuf>,
1960    ) -> ModelPaths {
1961        ModelPaths {
1962            transformer,
1963            transformer_shards,
1964            vae,
1965            spatial_upscaler: None,
1966            temporal_upscaler: None,
1967            distilled_lora: None,
1968            t5_encoder: None,
1969            clip_encoder: None,
1970            t5_tokenizer: None,
1971            clip_tokenizer: None,
1972            clip_encoder_2: None,
1973            clip_tokenizer_2: None,
1974            text_encoder_files: vec![],
1975            text_tokenizer,
1976            decoder: None,
1977        }
1978    }
1979
1980    #[test]
1981    fn zimage_safetensors_backend_accepts_civitai_diffusion_prefix() {
1982        use candle_nn::var_builder::SimpleBackend;
1983        use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
1984        use std::collections::HashMap;
1985
1986        fn f32_bytes(values: &[f32]) -> Vec<u8> {
1987            values
1988                .iter()
1989                .flat_map(|value| value.to_le_bytes())
1990                .collect()
1991        }
1992
1993        let dir = temp_test_dir("mold-zimage-prefix-backend");
1994        let path = dir.join("zimage.safetensors");
1995        let data = f32_bytes(&[42.0]);
1996        let qkv = f32_bytes(&[1.0, 2.0, 3.0]);
1997        let out = f32_bytes(&[7.0]);
1998        let q_norm = f32_bytes(&[8.0]);
1999        let k_norm = f32_bytes(&[9.0]);
2000        let mut tensors = HashMap::new();
2001        tensors.insert(
2002            format!("{ZIMAGE_SINGLE_FILE_PREFIX}t_embedder.mlp.0.weight"),
2003            TensorView::new(SafeDtype::F32, vec![1, 1], data.as_slice()).unwrap(),
2004        );
2005        tensors.insert(
2006            format!("{ZIMAGE_SINGLE_FILE_PREFIX}x_embedder.weight"),
2007            TensorView::new(SafeDtype::F32, vec![1, 1], data.as_slice()).unwrap(),
2008        );
2009        tensors.insert(
2010            format!("{ZIMAGE_SINGLE_FILE_PREFIX}noise_refiner.0.attention.qkv.weight"),
2011            TensorView::new(SafeDtype::F32, vec![3, 1], qkv.as_slice()).unwrap(),
2012        );
2013        tensors.insert(
2014            format!("{ZIMAGE_SINGLE_FILE_PREFIX}noise_refiner.0.attention.out.weight"),
2015            TensorView::new(SafeDtype::F32, vec![1, 1], out.as_slice()).unwrap(),
2016        );
2017        tensors.insert(
2018            format!("{ZIMAGE_SINGLE_FILE_PREFIX}noise_refiner.0.attention.q_norm.weight"),
2019            TensorView::new(SafeDtype::F32, vec![1], q_norm.as_slice()).unwrap(),
2020        );
2021        tensors.insert(
2022            format!("{ZIMAGE_SINGLE_FILE_PREFIX}noise_refiner.0.attention.k_norm.weight"),
2023            TensorView::new(SafeDtype::F32, vec![1], k_norm.as_slice()).unwrap(),
2024        );
2025        serialize_to_file(&tensors, &None, &path).unwrap();
2026
2027        let st = unsafe { candle_core::safetensors::MmapedSafetensors::multi(&[path.as_path()]) }
2028            .unwrap();
2029        let backend = ZImageSafetensorsBackend::new(st);
2030        assert!(backend.contains_tensor("t_embedder.mlp.0.weight"));
2031        let tensor = backend
2032            .get_unchecked("t_embedder.mlp.0.weight", DType::F32, &Device::Cpu)
2033            .unwrap();
2034        assert_eq!(tensor.to_vec2::<f32>().unwrap(), vec![vec![42.0]]);
2035        assert!(backend.contains_tensor("all_x_embedder.2-1.weight"));
2036        let alias_tensor = backend
2037            .get_unchecked("all_x_embedder.2-1.weight", DType::F32, &Device::Cpu)
2038            .unwrap();
2039        assert_eq!(alias_tensor.to_vec2::<f32>().unwrap(), vec![vec![42.0]]);
2040        assert!(backend.contains_tensor("noise_refiner.0.attention.to_q.weight"));
2041        assert!(backend.contains_tensor("noise_refiner.0.attention.to_k.weight"));
2042        assert!(backend.contains_tensor("noise_refiner.0.attention.to_v.weight"));
2043        let k = backend
2044            .get(
2045                Shape::from((1, 1)),
2046                "noise_refiner.0.attention.to_k.weight",
2047                candle_nn::Init::Const(0.0),
2048                DType::F32,
2049                &Device::Cpu,
2050            )
2051            .unwrap();
2052        assert_eq!(k.to_vec2::<f32>().unwrap(), vec![vec![2.0]]);
2053        let out = backend
2054            .get_unchecked(
2055                "noise_refiner.0.attention.to_out.0.weight",
2056                DType::F32,
2057                &Device::Cpu,
2058            )
2059            .unwrap();
2060        assert_eq!(out.to_vec2::<f32>().unwrap(), vec![vec![7.0]]);
2061        let q_norm = backend
2062            .get_unchecked(
2063                "noise_refiner.0.attention.norm_q.weight",
2064                DType::F32,
2065                &Device::Cpu,
2066            )
2067            .unwrap();
2068        assert_eq!(q_norm.to_vec1::<f32>().unwrap(), vec![8.0]);
2069        let k_norm = backend
2070            .get_unchecked(
2071                "noise_refiner.0.attention.norm_k.weight",
2072                DType::F32,
2073                &Device::Cpu,
2074            )
2075            .unwrap();
2076        assert_eq!(k_norm.to_vec1::<f32>().unwrap(), vec![9.0]);
2077
2078        let _ = std::fs::remove_dir_all(dir);
2079    }
2080
2081    #[test]
2082    fn zimage_vae_backend_accepts_bare_ldm_vae_keys() {
2083        use candle_nn::var_builder::SimpleBackend;
2084        use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
2085        use std::collections::HashMap;
2086
2087        fn f32_bytes(values: &[f32]) -> Vec<u8> {
2088            values
2089                .iter()
2090                .flat_map(|value| value.to_le_bytes())
2091                .collect()
2092        }
2093
2094        let dir = temp_test_dir("mold-zimage-vae-backend");
2095        let path = dir.join("vae.safetensors");
2096        let norm = f32_bytes(&[5.0]);
2097        let conv = f32_bytes(&[7.0]);
2098        let attn_q = f32_bytes(&[1.0, 2.0, 3.0, 4.0]);
2099        let mut tensors = HashMap::new();
2100        tensors.insert(
2101            "encoder.down.0.block.0.norm1.weight".to_string(),
2102            TensorView::new(SafeDtype::F32, vec![1], norm.as_slice()).unwrap(),
2103        );
2104        tensors.insert(
2105            "decoder.norm_out.weight".to_string(),
2106            TensorView::new(SafeDtype::F32, vec![1], conv.as_slice()).unwrap(),
2107        );
2108        tensors.insert(
2109            "encoder.mid.attn_1.q.weight".to_string(),
2110            TensorView::new(SafeDtype::F32, vec![2, 2, 1, 1], attn_q.as_slice()).unwrap(),
2111        );
2112        serialize_to_file(&tensors, &None, &path).unwrap();
2113
2114        let st = unsafe { candle_core::safetensors::MmapedSafetensors::multi(&[path.as_path()]) }
2115            .unwrap();
2116        let backend = ZImageVaeSafetensorsBackend::new(st);
2117
2118        assert!(backend.contains_tensor("encoder.down_blocks.0.resnets.0.norm1.weight"));
2119        let norm = backend
2120            .get_unchecked(
2121                "encoder.down_blocks.0.resnets.0.norm1.weight",
2122                DType::F32,
2123                &Device::Cpu,
2124            )
2125            .unwrap();
2126        assert_eq!(norm.to_vec1::<f32>().unwrap(), vec![5.0]);
2127
2128        assert!(backend.contains_tensor("decoder.conv_norm_out.weight"));
2129        let conv = backend
2130            .get_unchecked("decoder.conv_norm_out.weight", DType::F32, &Device::Cpu)
2131            .unwrap();
2132        assert_eq!(conv.to_vec1::<f32>().unwrap(), vec![7.0]);
2133        let q = backend
2134            .get(
2135                Shape::from((2, 2)),
2136                "encoder.mid_block.attentions.0.to_q.weight",
2137                candle_nn::Init::Const(0.0),
2138                DType::F32,
2139                &Device::Cpu,
2140            )
2141            .unwrap();
2142        assert_eq!(
2143            q.to_vec2::<f32>().unwrap(),
2144            vec![vec![1.0, 2.0], vec![3.0, 4.0]]
2145        );
2146
2147        let _ = std::fs::remove_dir_all(dir);
2148    }
2149
2150    #[test]
2151    fn zimage_vae_cpu_tensor_backend_preserves_aliases_and_reshape() {
2152        use candle_nn::var_builder::SimpleBackend;
2153        use std::collections::HashMap;
2154
2155        let device = Device::Cpu;
2156        let mut tensors = HashMap::new();
2157        tensors.insert(
2158            "encoder.down.0.block.0.norm1.weight".to_string(),
2159            Tensor::new(&[5.0f32], &device).unwrap(),
2160        );
2161        tensors.insert(
2162            "encoder.mid.attn_1.q.weight".to_string(),
2163            Tensor::new(&[1.0f32, 2.0, 3.0, 4.0], &device)
2164                .unwrap()
2165                .reshape((2, 2, 1, 1))
2166                .unwrap(),
2167        );
2168
2169        let backend = ZImageVaeSafetensorsBackend::from_cpu_tensors(Arc::new(tensors));
2170
2171        assert!(backend.contains_tensor("encoder.down_blocks.0.resnets.0.norm1.weight"));
2172        let norm = backend
2173            .get_unchecked(
2174                "encoder.down_blocks.0.resnets.0.norm1.weight",
2175                DType::F32,
2176                &device,
2177            )
2178            .unwrap();
2179        assert_eq!(norm.to_vec1::<f32>().unwrap(), vec![5.0]);
2180
2181        let q = backend
2182            .get(
2183                Shape::from((2, 2)),
2184                "encoder.mid_block.attentions.0.to_q.weight",
2185                candle_nn::Init::Const(0.0),
2186                DType::F32,
2187                &device,
2188            )
2189            .unwrap();
2190        assert_eq!(
2191            q.to_vec2::<f32>().unwrap(),
2192            vec![vec![1.0, 2.0], vec![3.0, 4.0]]
2193        );
2194    }
2195
2196    #[test]
2197    fn zimage_loads_vae_tensors_through_shared_pool() {
2198        use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
2199        use std::collections::HashMap;
2200
2201        let dir = temp_test_dir("mold-zimage-vae-pool");
2202        let vae_path = dir.join("vae.safetensors");
2203        let weight = 1.0f32.to_le_bytes();
2204        let mut tensors = HashMap::new();
2205        tensors.insert(
2206            "encoder.conv_in.weight".to_string(),
2207            TensorView::new(SafeDtype::F32, vec![1], weight.as_slice()).unwrap(),
2208        );
2209        serialize_to_file(&tensors, &None, &vae_path).unwrap();
2210
2211        let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
2212        let pooled = shared_pool
2213            .lock()
2214            .unwrap()
2215            .load_safetensors_cpu_tensors(std::slice::from_ref(&vae_path))
2216            .unwrap()
2217            .unwrap();
2218
2219        let engine = ZImageEngine::new(
2220            "z-image-turbo:q4".to_string(),
2221            zimage_model_paths(
2222                dir.join("transformer.gguf"),
2223                vec![],
2224                vae_path,
2225                Some(dir.join("tokenizer.json")),
2226            ),
2227            None,
2228            LoadStrategy::Sequential,
2229            0,
2230            false,
2231            Some(shared_pool),
2232        );
2233
2234        let loaded = engine.load_vae_cpu_tensors().unwrap().unwrap();
2235
2236        assert!(Arc::ptr_eq(&pooled, &loaded));
2237        fs::remove_dir_all(dir).ok();
2238    }
2239
2240    #[test]
2241    fn zimage_recipe_text_encoder_defaults_to_bf16_variant() {
2242        let recipe_paths = vec![std::path::PathBuf::from(
2243            "/models/cv-2442439/z-image/civitai/2442439/zImageTurbo_turbo_txt.safetensors",
2244        )];
2245        let shared_paths = vec![std::path::PathBuf::from(
2246            "/models/shared/z-image/text_encoder/model-00001-of-00003.safetensors",
2247        )];
2248
2249        assert_eq!(zimage_qwen3_preference(None, &recipe_paths), Some("bf16"));
2250        assert_eq!(zimage_qwen3_preference(None, &shared_paths), None);
2251        assert_eq!(
2252            zimage_qwen3_preference(Some("q8"), &recipe_paths),
2253            Some("q8")
2254        );
2255        assert_eq!(
2256            zimage_qwen3_preference(Some("auto"), &recipe_paths),
2257            Some("auto")
2258        );
2259    }
2260
2261    #[test]
2262    fn latent_dimensions() {
2263        // 1024px → 2 * (1024 / 16) = 128
2264        assert_eq!(2 * (1024 / 16), 128);
2265        // 512px → 2 * (512 / 16) = 64
2266        assert_eq!(2 * (512 / 16), 64);
2267        // 768px → 2 * (768 / 16) = 96
2268        assert_eq!(2 * (768 / 16), 96);
2269    }
2270
2271    // --- VRAM threshold decision tests (with drop-and-reload) ---
2272
2273    #[test]
2274    fn qwen3_on_gpu_on_24gb_with_q8_drop_reload() {
2275        // Q8 transformer (6.6GB) on 24GB card → ~17GB free
2276        // With drop-and-reload, threshold is 10.2GB → fits on GPU!
2277        assert!(should_use_gpu(
2278            true,
2279            false,
2280            17_000_000_000,
2281            QWEN3_FP16_VRAM_THRESHOLD
2282        ));
2283    }
2284
2285    #[test]
2286    fn qwen3_on_gpu_on_24gb_with_q4_drop_reload() {
2287        // Q4 transformer (3.9GB) on 24GB card → ~19GB free
2288        // With drop-and-reload, easily fits on GPU
2289        assert!(should_use_gpu(
2290            true,
2291            false,
2292            19_000_000_000,
2293            QWEN3_FP16_VRAM_THRESHOLD
2294        ));
2295    }
2296
2297    #[test]
2298    fn qwen3_on_cpu_with_bf16_transformer() {
2299        // BF16 transformer (24.6GB) on 24GB card → ~0GB free
2300        // Even with drop-and-reload, can't fit
2301        assert!(!should_use_gpu(
2302            true,
2303            false,
2304            400_000_000,
2305            QWEN3_FP16_VRAM_THRESHOLD
2306        ));
2307    }
2308
2309    #[test]
2310    fn qwen3_on_gpu_on_48gb_card() {
2311        assert!(should_use_gpu(
2312            true,
2313            false,
2314            40_000_000_000,
2315            QWEN3_FP16_VRAM_THRESHOLD
2316        ));
2317    }
2318
2319    #[test]
2320    fn qwen3_on_gpu_on_metal() {
2321        // Metal with no memory info falls back to true
2322        assert!(should_use_gpu(false, true, 0, QWEN3_FP16_VRAM_THRESHOLD));
2323    }
2324
2325    #[test]
2326    fn vae_on_gpu_when_plenty_of_vram() {
2327        assert!(should_use_gpu(
2328            true,
2329            false,
2330            17_000_000_000,
2331            VAE_DECODE_VRAM_THRESHOLD
2332        ));
2333    }
2334
2335    #[test]
2336    fn eager_vae_weight_load_threshold_is_below_decode_workspace_threshold() {
2337        const {
2338            assert!(VAE_WEIGHT_LOAD_VRAM_THRESHOLD < VAE_DECODE_VRAM_THRESHOLD);
2339        }
2340        assert!(should_use_gpu(
2341            true,
2342            false,
2343            1_000_000_000,
2344            VAE_WEIGHT_LOAD_VRAM_THRESHOLD
2345        ));
2346    }
2347
2348    #[test]
2349    fn vae_on_cpu_when_vram_tight() {
2350        assert!(!should_use_gpu(
2351            true,
2352            false,
2353            5_400_000_000,
2354            VAE_DECODE_VRAM_THRESHOLD
2355        ));
2356    }
2357
2358    #[test]
2359    fn vae_on_gpu_on_metal() {
2360        // Metal with no memory info falls back to true
2361        assert!(should_use_gpu(false, true, 0, VAE_DECODE_VRAM_THRESHOLD));
2362    }
2363
2364    // --- Threshold sanity checks ---
2365
2366    #[test]
2367    fn qwen3_threshold_allows_gpu_on_24gb_with_quantized_xformer() {
2368        // Key improvement: with drop-and-reload, BF16 Qwen3 fits on GPU
2369        // when quantized transformer is used on 24GB cards
2370        let threshold = std::hint::black_box(QWEN3_FP16_VRAM_THRESHOLD);
2371        assert!(threshold < 17_000_000_000);
2372    }
2373
2374    #[test]
2375    fn qwen3_threshold_exceeds_encoder_size() {
2376        let threshold = std::hint::black_box(QWEN3_FP16_VRAM_THRESHOLD);
2377        assert!(threshold > 8_200_000_000);
2378    }
2379
2380    #[test]
2381    fn vae_threshold_accounts_for_decode_workspace() {
2382        let threshold = std::hint::black_box(VAE_DECODE_VRAM_THRESHOLD);
2383        assert!(threshold > 160_000_000);
2384        assert!(threshold < 15_000_000_000);
2385    }
2386
2387    #[test]
2388    fn zimage_scheduler_uses_shifted_reference_sigmas() {
2389        let image_seq_len = 1024;
2390        let (full, _) = build_zimage_scheduler(9, image_seq_len, None);
2391        let (scheduler, start_index) = build_zimage_scheduler(9, image_seq_len, Some(0.5));
2392        let expected_sigmas = full.sigmas[start_index..].to_vec();
2393        let expected_timesteps = expected_sigmas[..expected_sigmas.len() - 1]
2394            .iter()
2395            .map(|sigma| sigma * 1000.0)
2396            .collect::<Vec<_>>();
2397
2398        assert_eq!(start_index, crate::img2img::img2img_start_index(9, 0.5));
2399        assert_eq!(scheduler.sigmas, expected_sigmas);
2400        assert_eq!(scheduler.timesteps, expected_timesteps);
2401        assert_eq!(scheduler.sigmas.last().copied(), Some(0.0));
2402    }
2403
2404    #[test]
2405    fn zimage_model_timestep_matches_scheduler_timesteps() {
2406        let (scheduler, _) = build_zimage_scheduler(9, 1024, Some(0.5));
2407        let t = model_timestep(&scheduler);
2408        assert!(
2409            (t - (1.0 - scheduler.sigmas[0])).abs() < 1e-10,
2410            "expected model timestep to match 1-sigma semantics, got {t} vs {}",
2411            1.0 - scheduler.sigmas[0]
2412        );
2413    }
2414
2415    #[test]
2416    fn zimage_img2img_source_decode_uses_vae_native_zero_to_one_range() {
2417        let source = include_str!("pipeline.rs")
2418            .split("#[cfg(test)]\nmod tests")
2419            .next()
2420            .expect("pipeline source should include production section");
2421        let decode_sites = source
2422            .split("let source_tensor = img_utils::decode_source_image(")
2423            .skip(1)
2424            .collect::<Vec<_>>();
2425
2426        assert_eq!(decode_sites.len(), 2);
2427        for site in decode_sites {
2428            let args = site
2429                .split(")?;")
2430                .next()
2431                .expect("source decode call should terminate");
2432            assert!(
2433                args.contains("img_utils::NormalizeRange::ZeroToOne"),
2434                "Z-Image source-image encoding must use the VAE-native [0, 1] range"
2435            );
2436            assert!(
2437                !args.contains("img_utils::NormalizeRange::MinusOneToOne"),
2438                "Z-Image source-image encoding must not use [-1, 1] normalization"
2439            );
2440        }
2441    }
2442
2443    #[test]
2444    fn zimage_zero_strength_preserves_terminal_zero_only() {
2445        let (scheduler, start_index) = build_zimage_scheduler(9, 1024, Some(0.0));
2446
2447        assert_eq!(start_index, 9);
2448        assert_eq!(scheduler.sigmas, vec![0.0]);
2449        assert!(scheduler.timesteps.is_empty());
2450    }
2451
2452    #[test]
2453    fn tensor_stats_summary_reports_expected_values() {
2454        let tensor =
2455            Tensor::from_vec(vec![1.0f32, -1.0, 3.0, -3.0], (1, 1, 2, 2), &Device::Cpu).unwrap();
2456        let summary = tensor_stats_summary("probe", &tensor).unwrap();
2457
2458        assert!(summary.contains("probe:"));
2459        assert!(summary.contains("mean=0.00000"));
2460        assert!(summary.contains("min=-3.00000"));
2461        assert!(summary.contains("max=3.00000"));
2462        assert!(summary.contains("rms=2.23607"));
2463    }
2464
2465    #[test]
2466    fn zimage_transformer_paths_prefer_shards_when_present() {
2467        let dir = temp_test_dir("mold-zimage-shards");
2468        let shard_a = touch(&dir, "transformer-00001-of-00002.safetensors");
2469        let shard_b = touch(&dir, "transformer-00002-of-00002.safetensors");
2470        let engine = ZImageEngine::new(
2471            "z-image-turbo:bf16".to_string(),
2472            zimage_model_paths(
2473                dir.join("transformer.safetensors"),
2474                vec![shard_a.clone(), shard_b.clone()],
2475                dir.join("vae.safetensors"),
2476                Some(dir.join("tokenizer.json")),
2477            ),
2478            None,
2479            LoadStrategy::Sequential,
2480            0,
2481            false,
2482            None,
2483        );
2484
2485        assert_eq!(engine.transformer_paths(), vec![shard_a, shard_b]);
2486
2487        fs::remove_dir_all(dir).ok();
2488    }
2489
2490    #[test]
2491    fn zimage_validate_paths_accepts_existing_files() {
2492        let dir = temp_test_dir("mold-zimage-validate-ok");
2493        let shard_a = touch(&dir, "transformer-00001-of-00002.safetensors");
2494        let shard_b = touch(&dir, "transformer-00002-of-00002.safetensors");
2495        let vae = touch(&dir, "vae.safetensors");
2496        let tokenizer = touch(&dir, "tokenizer.json");
2497        let gguf = touch(&dir, "transformer.gguf");
2498
2499        let sharded = ZImageEngine::new(
2500            "z-image-turbo:bf16".to_string(),
2501            zimage_model_paths(
2502                dir.join("transformer.safetensors"),
2503                vec![shard_a, shard_b],
2504                vae.clone(),
2505                Some(tokenizer.clone()),
2506            ),
2507            None,
2508            LoadStrategy::Sequential,
2509            0,
2510            false,
2511            None,
2512        );
2513        assert_eq!(sharded.validate_paths().unwrap(), tokenizer);
2514        assert!(!sharded.detect_is_gguf());
2515
2516        let quantized = ZImageEngine::new(
2517            "z-image-turbo:q4".to_string(),
2518            zimage_model_paths(gguf, vec![], vae, Some(dir.join("tokenizer.json"))),
2519            None,
2520            LoadStrategy::Sequential,
2521            0,
2522            false,
2523            None,
2524        );
2525        assert!(quantized.detect_is_gguf());
2526
2527        fs::remove_dir_all(dir).ok();
2528    }
2529
2530    #[test]
2531    fn zimage_lora_requests_use_sequential_generation_path() {
2532        let dir = temp_test_dir("mold-zimage-lora-sequential");
2533        let mut engine = ZImageEngine::new(
2534            "z-image-turbo:q8".to_string(),
2535            zimage_model_paths(
2536                dir.join("transformer.gguf"),
2537                vec![],
2538                dir.join("vae.safetensors"),
2539                Some(dir.join("tokenizer.json")),
2540            ),
2541            None,
2542            LoadStrategy::Eager,
2543            0,
2544            false,
2545            None,
2546        );
2547        engine.pending_loras = vec![LoraWeight {
2548            path: dir.join("adapter.safetensors").display().to_string(),
2549            scale: 1.0,
2550        }];
2551
2552        assert!(
2553            engine.uses_sequential_generate_path(),
2554            "Z-Image LoRA requests should use staged load-use-drop generation \
2555             so VAE/text encoders are not co-resident with the LoRA-merged transformer"
2556        );
2557
2558        fs::remove_dir_all(dir).ok();
2559    }
2560
2561    #[test]
2562    fn zimage_sequential_path_drops_eager_components_before_generation() {
2563        let source = include_str!("pipeline.rs");
2564        let sequential_branch = source
2565            .split("// Eager mode: use pre-loaded components")
2566            .next()
2567            .expect("generate_inner should contain eager-mode marker");
2568
2569        assert!(
2570            sequential_branch.contains("self.base.unload();")
2571                && sequential_branch.contains("return self.generate_sequential(req);"),
2572            "Z-Image LoRA/offload sequential generation must drop eager-loaded \
2573             components before loading staged components"
2574        );
2575    }
2576
2577    #[test]
2578    fn zimage_eager_path_reloads_after_sequential_generation_unloads_components() {
2579        let source = include_str!("pipeline.rs");
2580        let eager_branch = source
2581            .split("// Eager mode: use pre-loaded components")
2582            .nth(1)
2583            .expect("generate_inner should contain eager-mode branch");
2584        let reload_idx = eager_branch
2585            .find("self.load()?;")
2586            .expect("eager branch should reload an unloaded cached engine");
2587        let guard_idx = eager_branch
2588            .find("bail!(\"model not loaded")
2589            .expect("eager branch should retain a final loaded-state guard");
2590
2591        assert!(
2592            reload_idx < guard_idx,
2593            "Z-Image eager generation must reload after a prior LoRA/offload \
2594             sequential request unloads cached components"
2595        );
2596    }
2597
2598    #[test]
2599    fn zimage_forced_offload_uses_sequential_generation_path() {
2600        let dir = temp_test_dir("mold-zimage-offload-sequential");
2601        let engine = ZImageEngine::new(
2602            "z-image-turbo:bf16".to_string(),
2603            zimage_model_paths(
2604                dir.join("transformer.safetensors"),
2605                vec![],
2606                dir.join("vae.safetensors"),
2607                Some(dir.join("tokenizer.json")),
2608            ),
2609            None,
2610            LoadStrategy::Eager,
2611            0,
2612            true,
2613            None,
2614        );
2615
2616        assert!(
2617            engine.uses_sequential_generate_path(),
2618            "Z-Image --offload requests must reach the engine and select the \
2619             staged generation path instead of being silently ignored"
2620        );
2621
2622        fs::remove_dir_all(dir).ok();
2623    }
2624
2625    #[test]
2626    fn zimage_offload_decision_gates_current_unsupported_cases() {
2627        assert_eq!(
2628            zimage_offload_decision(false, false, false),
2629            ZImageOffloadDecision::Disabled
2630        );
2631        assert_eq!(
2632            zimage_offload_decision(true, false, false),
2633            ZImageOffloadDecision::Selected
2634        );
2635        assert!(matches!(
2636            zimage_offload_decision(true, true, false),
2637            ZImageOffloadDecision::Unsupported(reason)
2638                if reason.contains("GGUF variants")
2639        ));
2640        assert!(matches!(
2641            zimage_offload_decision(true, false, true),
2642            ZImageOffloadDecision::Unsupported(reason)
2643                if reason.contains("LoRA")
2644        ));
2645    }
2646
2647    #[test]
2648    fn zimage_selected_bf16_offload_reaches_runtime_loader() {
2649        let dir = temp_test_dir("mold-zimage-offload-loader");
2650        let mut engine = ZImageEngine::new(
2651            "z-image-turbo:bf16".to_string(),
2652            zimage_model_paths(
2653                touch(&dir, "transformer.safetensors"),
2654                vec![],
2655                touch(&dir, "vae.safetensors"),
2656                Some(touch(&dir, "tokenizer.json")),
2657            ),
2658            None,
2659            LoadStrategy::Sequential,
2660            0,
2661            true,
2662            None,
2663        );
2664        let req = GenerateRequest {
2665            prompt: "a cat".to_string(),
2666            negative_prompt: None,
2667            model: "z-image-turbo:bf16".to_string(),
2668            width: 64,
2669            height: 64,
2670            steps: 1,
2671            guidance: 0.0,
2672            seed: Some(1),
2673            batch_size: 1,
2674            output_format: None,
2675            embed_metadata: None,
2676            scheduler: None,
2677            cfg_plus: None,
2678            source_image: None,
2679            edit_images: None,
2680            strength: 1.0,
2681            mask_image: None,
2682            control_image: None,
2683            control_model: None,
2684            control_scale: 1.0,
2685            expand: None,
2686            original_prompt: None,
2687            lora: None,
2688            frames: None,
2689            fps: None,
2690            upscale_model: None,
2691            gif_preview: false,
2692            enable_audio: None,
2693            audio_file: None,
2694            audio_file_path: None,
2695            source_video: None,
2696            source_video_path: None,
2697            keyframes: None,
2698            pipeline: None,
2699            loras: None,
2700            retake_range: None,
2701            spatial_upscale: None,
2702            temporal_upscale: None,
2703            placement: None,
2704        };
2705
2706        let err = engine.generate_sequential(&req).unwrap_err().to_string();
2707
2708        assert!(
2709            !err.contains("streaming is not implemented yet"),
2710            "selected BF16 offload must reach the runtime loader, got: {err}"
2711        );
2712        fs::remove_dir_all(dir).ok();
2713    }
2714
2715    #[test]
2716    fn zimage_validate_paths_requires_text_tokenizer() {
2717        let dir = temp_test_dir("mold-zimage-validate-missing");
2718        let engine = ZImageEngine::new(
2719            "z-image-turbo:q4".to_string(),
2720            zimage_model_paths(
2721                dir.join("transformer.gguf"),
2722                vec![],
2723                dir.join("vae.safetensors"),
2724                None,
2725            ),
2726            None,
2727            LoadStrategy::Sequential,
2728            0,
2729            false,
2730            None,
2731        );
2732
2733        let err = engine.validate_paths().unwrap_err();
2734        assert!(err.to_string().contains("text tokenizer path required"));
2735
2736        fs::remove_dir_all(dir).ok();
2737    }
2738
2739    #[test]
2740    fn zimage_loads_qwen3_tokenizer_through_shared_pool() {
2741        let dir = temp_test_dir("mold-zimage-tokenizer-pool");
2742        let tokenizer_path = dir.join("tokenizer.json");
2743        tokenizers::Tokenizer::new(BPE::default())
2744            .save(&tokenizer_path, false)
2745            .unwrap();
2746
2747        let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
2748        let pooled = shared_pool
2749            .lock()
2750            .unwrap()
2751            .load_tokenizer(&tokenizer_path)
2752            .unwrap();
2753
2754        let engine = ZImageEngine::new(
2755            "z-image-turbo:q4".to_string(),
2756            zimage_model_paths(
2757                dir.join("transformer.gguf"),
2758                vec![],
2759                dir.join("vae.safetensors"),
2760                Some(tokenizer_path.clone()),
2761            ),
2762            None,
2763            LoadStrategy::Sequential,
2764            0,
2765            false,
2766            Some(shared_pool),
2767        );
2768
2769        let loaded = engine.load_text_tokenizer(&tokenizer_path).unwrap();
2770
2771        assert!(Arc::ptr_eq(&pooled, &loaded));
2772        fs::remove_dir_all(dir).ok();
2773    }
2774}