1use anyhow::{bail, Result};
2use candle_core::{DType, Device, IndexOp, Shape, Tensor};
3use candle_transformers::models::z_image::{
4 calculate_shift, postprocess_image, AutoEncoderKL, Config, FlowMatchEulerDiscreteScheduler,
5 SchedulerConfig, VaeConfig,
6};
7use candle_transformers::quantized_var_builder;
8use mold_core::{GenerateRequest, GenerateResponse, ImageData, LoraWeight, ModelPaths};
9use std::borrow::Cow;
10use std::collections::{BTreeMap, HashMap};
11use std::path::Path;
12use std::sync::{Arc, Mutex};
13use std::time::Instant;
14use tokenizers::Tokenizer;
15
16use super::gguf_dense::load_gguf_dense_transformer;
17use super::transformer::{MoldZImageTransformer2DModel, ZImageTransformer};
18use crate::cache::{
19 clear_cache, get_or_insert_cached_tensor, prompt_text_key, restore_cached_tensor, CachedTensor,
20 LruCache, DEFAULT_PROMPT_CACHE_CAPACITY,
21};
22use crate::device::{
23 check_memory_budget, effective_device_ref, fmt_gb, free_vram_bytes, memory_status_string,
24 preflight_memory_check, should_use_gpu, usable_free_vram_bytes,
25};
26#[cfg(test)]
29use crate::device::QWEN3_FP16_VRAM_THRESHOLD;
30use crate::encoders;
31use crate::engine::{rand_seed, InferenceEngine, LoadStrategy};
32use crate::engine_base::EngineBase;
33use crate::image::{build_output_metadata, encode_image};
34use crate::img_utils;
35use crate::progress::{ProgressCallback, ProgressEvent, ProgressReporter};
36
/// Free-VRAM threshold in bytes associated with VAE decoding.
/// NOTE(review): not referenced in this part of the file — confirm where it is applied.
const VAE_DECODE_VRAM_THRESHOLD: u64 = 6_500_000_000;
/// Free-VRAM threshold in bytes handed to `should_use_gpu` in `ZImageEngine::load`
/// when deciding whether the VAE weights go on the GPU or on the CPU.
const VAE_WEIGHT_LOAD_VRAM_THRESHOLD: u64 = 600_000_000;

/// Lower image sequence-length bound passed to `calculate_shift`.
const BASE_IMAGE_SEQ_LEN: usize = 256;
/// Upper image sequence-length bound passed to `calculate_shift`.
const MAX_IMAGE_SEQ_LEN: usize = 4096;
/// Key prefix that `ZImageSafetensorsBackend` tries when a tensor is not stored
/// under its bare (or aliased) name.
const ZIMAGE_SINGLE_FILE_PREFIX: &str = "model.diffusion_model.";
50
51struct ZImageSafetensorsBackend {
52 st: candle_core::safetensors::MmapedSafetensors,
53}
54
55impl ZImageSafetensorsBackend {
56 fn new(st: candle_core::safetensors::MmapedSafetensors) -> Self {
57 Self { st }
58 }
59
60 fn resolve_stored_name<'a>(&'a self, name: &'a str) -> Option<Cow<'a, str>> {
61 if self.st.get(name).is_ok() {
62 return Some(Cow::Borrowed(name));
63 }
64 if let Some(alias) = zimage_safetensors_alias(name) {
65 if self.st.get(alias.as_ref()).is_ok() {
66 return Some(alias);
67 }
68 }
69 let prefixed = format!("{ZIMAGE_SINGLE_FILE_PREFIX}{name}");
70 if self.st.get(&prefixed).is_ok() {
71 return Some(Cow::Owned(prefixed));
72 }
73 if let Some(alias) = zimage_safetensors_alias(name) {
74 let prefixed_alias = format!("{ZIMAGE_SINGLE_FILE_PREFIX}{}", alias.as_ref());
75 if self.st.get(&prefixed_alias).is_ok() {
76 return Some(Cow::Owned(prefixed_alias));
77 }
78 }
79 None
80 }
81
82 fn stored_name<'a>(&'a self, name: &'a str) -> Cow<'a, str> {
83 self.resolve_stored_name(name)
84 .unwrap_or(Cow::Borrowed(name))
85 }
86
87 fn load_cast(&self, name: &str, dtype: DType, dev: &Device) -> candle_core::Result<Tensor> {
88 let stored_name = self.stored_name(name);
89 let tensor = self.st.load(stored_name.as_ref(), dev)?;
90 if tensor.dtype() != dtype {
91 tensor.to_dtype(dtype)
92 } else {
93 Ok(tensor)
94 }
95 }
96
97 fn load_tensor(
98 &self,
99 name: &str,
100 expected_shape: Option<&Shape>,
101 dtype: DType,
102 dev: &Device,
103 ) -> candle_core::Result<Tensor> {
104 if let Some((source_name, component)) = zimage_qkv_request(name) {
105 return self.load_qkv_split(&source_name, component, expected_shape, dtype, dev);
106 }
107 self.load_cast(name, dtype, dev)
108 }
109
110 fn load_qkv_split(
111 &self,
112 source_name: &str,
113 component: usize,
114 expected_shape: Option<&Shape>,
115 dtype: DType,
116 dev: &Device,
117 ) -> candle_core::Result<Tensor> {
118 let qkv = self.load_cast(source_name, dtype, dev)?;
119 let rows = qkv.dim(0)?;
120 let split_rows = expected_shape
121 .and_then(|shape| shape.dims().first().copied())
122 .unwrap_or(rows / 3);
123 if component >= 3 || split_rows == 0 || rows != split_rows * 3 {
124 return Err(candle_core::Error::msg(format!(
125 "invalid fused QKV shape for {source_name}: rows={rows}, split_rows={split_rows}"
126 )));
127 }
128 qkv.narrow(0, component * split_rows, split_rows)?
129 .contiguous()
130 }
131}
132
133impl candle_nn::var_builder::SimpleBackend for ZImageSafetensorsBackend {
134 fn get(
135 &self,
136 shape: Shape,
137 name: &str,
138 _init: candle_nn::Init,
139 dtype: DType,
140 dev: &Device,
141 ) -> candle_core::Result<Tensor> {
142 let tensor = self.load_tensor(name, Some(&shape), dtype, dev)?;
143 if tensor.shape() != &shape {
144 Err(candle_core::Error::UnexpectedShape {
145 msg: format!("shape mismatch for {name}"),
146 expected: shape,
147 got: tensor.shape().clone(),
148 })?
149 }
150 Ok(tensor)
151 }
152
153 fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> candle_core::Result<Tensor> {
154 self.load_tensor(name, None, dtype, dev)
155 }
156
157 fn contains_tensor(&self, name: &str) -> bool {
158 if let Some((source_name, _)) = zimage_qkv_request(name) {
159 return self.resolve_stored_name(&source_name).is_some();
160 }
161 self.resolve_stored_name(name).is_some()
162 }
163}
164
/// Where `ZImageVaeSafetensorsBackend` reads its tensors from.
enum ZImageVaeTensorSource {
    /// Memory-mapped safetensors; tensors are loaded from the mapping on request.
    Mmap(candle_core::safetensors::MmapedSafetensors),
    /// Shared map of already materialised tensors, copied to the requested
    /// device with `to_device` on each load (presumably CPU-resident — confirm
    /// against the code that fills the map).
    Cpu(Arc<HashMap<String, Tensor>>),
}
169
/// `SimpleBackend` for the Z-Image VAE that accepts tensors stored under
/// either the requested names or names that `zimage_vae_diffusers_name` can
/// translate to them.
struct ZImageVaeSafetensorsBackend {
    /// Backing tensor storage.
    source: ZImageVaeTensorSource,
    /// Maps a translated (requested) name to the name actually stored in `source`.
    aliases: BTreeMap<String, String>,
}
177
178impl ZImageVaeSafetensorsBackend {
179 fn new(st: candle_core::safetensors::MmapedSafetensors) -> Self {
180 let aliases = Self::aliases_from_names(st.tensors().into_iter().map(|(name, _)| name));
181 Self {
182 source: ZImageVaeTensorSource::Mmap(st),
183 aliases,
184 }
185 }
186
187 fn from_cpu_tensors(tensors: Arc<HashMap<String, Tensor>>) -> Self {
188 let aliases = Self::aliases_from_names(tensors.keys().cloned());
189 Self {
190 source: ZImageVaeTensorSource::Cpu(tensors),
191 aliases,
192 }
193 }
194
195 fn aliases_from_names(names: impl IntoIterator<Item = String>) -> BTreeMap<String, String> {
196 names
197 .into_iter()
198 .filter_map(|name| zimage_vae_diffusers_name(&name).map(|diffusers| (diffusers, name)))
199 .collect()
200 }
201
202 fn resolve_stored_name<'a>(&'a self, name: &'a str) -> Cow<'a, str> {
203 if self.contains_stored_tensor(name) {
204 return Cow::Borrowed(name);
205 }
206 self.aliases
207 .get(name)
208 .map(|source| Cow::Borrowed(source.as_str()))
209 .unwrap_or(Cow::Borrowed(name))
210 }
211
212 fn contains_stored_tensor(&self, name: &str) -> bool {
213 match &self.source {
214 ZImageVaeTensorSource::Mmap(st) => st.get(name).is_ok(),
215 ZImageVaeTensorSource::Cpu(tensors) => tensors.contains_key(name),
216 }
217 }
218
219 fn load_cast(&self, name: &str, dtype: DType, dev: &Device) -> candle_core::Result<Tensor> {
220 let stored_name = self.resolve_stored_name(name);
221 let tensor = match &self.source {
222 ZImageVaeTensorSource::Mmap(st) => st.load(stored_name.as_ref(), dev)?,
223 ZImageVaeTensorSource::Cpu(tensors) => tensors
224 .get(stored_name.as_ref())
225 .ok_or_else(|| {
226 candle_core::Error::msg(format!(
227 "missing Z-Image VAE tensor {}",
228 stored_name.as_ref()
229 ))
230 })?
231 .to_device(dev)?,
232 };
233 if tensor.dtype() != dtype {
234 tensor.to_dtype(dtype)
235 } else {
236 Ok(tensor)
237 }
238 }
239}
240
241impl candle_nn::var_builder::SimpleBackend for ZImageVaeSafetensorsBackend {
242 fn get(
243 &self,
244 shape: Shape,
245 name: &str,
246 _init: candle_nn::Init,
247 dtype: DType,
248 dev: &Device,
249 ) -> candle_core::Result<Tensor> {
250 let mut tensor = self.load_cast(name, dtype, dev)?;
251 if tensor.shape() != &shape
252 && tensor.dims().len() == 4
253 && shape.dims().len() == 2
254 && tensor.dims()[0] == shape.dims()[0]
255 && tensor.dims()[1] == shape.dims()[1]
256 && tensor.dims()[2] == 1
257 && tensor.dims()[3] == 1
258 {
259 tensor = tensor.reshape(shape.dims())?;
260 }
261 if tensor.shape() != &shape {
262 Err(candle_core::Error::UnexpectedShape {
263 msg: format!("shape mismatch for {name}"),
264 expected: shape,
265 got: tensor.shape().clone(),
266 })?
267 }
268 Ok(tensor)
269 }
270
271 fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> candle_core::Result<Tensor> {
272 self.load_cast(name, dtype, dev)
273 }
274
275 fn contains_tensor(&self, name: &str) -> bool {
276 self.contains_stored_tensor(name) || self.aliases.contains_key(name)
277 }
278}
279
280fn zimage_vae_diffusers_name(source_name: &str) -> Option<String> {
281 if source_name.starts_with("first_stage_model.") {
282 return crate::loader::vae_keys::apply_vae_rename(source_name);
283 }
284 if source_name.starts_with("encoder.")
285 || source_name.starts_with("decoder.")
286 || source_name.starts_with("quant_conv.")
287 || source_name.starts_with("post_quant_conv.")
288 {
289 return crate::loader::vae_keys::apply_vae_rename(&format!(
290 "first_stage_model.{source_name}"
291 ));
292 }
293 None
294}
295
/// Recognises a request for a separate attention Q/K/V projection weight.
///
/// Returns the name of the corresponding fused `…attention.qkv.weight` tensor
/// together with the slice index (0 = `to_q`, 1 = `to_k`, 2 = `to_v`), or
/// `None` when `name` has none of those suffixes.
fn zimage_qkv_request(name: &str) -> Option<(String, usize)> {
    const PROJECTION_SUFFIXES: [&str; 3] = [
        ".attention.to_q.weight",
        ".attention.to_k.weight",
        ".attention.to_v.weight",
    ];

    PROJECTION_SUFFIXES
        .iter()
        .enumerate()
        .find_map(|(component, suffix)| {
            let prefix = name.strip_suffix(*suffix)?;
            Some((format!("{prefix}.attention.qkv.weight"), component))
        })
}
308
/// Maps a requested transformer tensor name to the alternative name it may be
/// stored under.
///
/// A fixed set of `all_x_embedder.2-1.*` / `all_final_layer.2-1.*` names maps
/// to their `x_embedder.*` / `final_layer.*` counterparts (borrowed, no
/// allocation). Attention `to_out.0`, `norm_q` and `norm_k` weights map to
/// `out`, `q_norm` and `k_norm` with the original prefix preserved. Returns
/// `None` for every other name.
fn zimage_safetensors_alias(name: &str) -> Option<Cow<'_, str>> {
    const SUFFIX_RENAMES: [(&str, &str); 3] = [
        (".attention.to_out.0.weight", ".attention.out.weight"),
        (".attention.norm_q.weight", ".attention.q_norm.weight"),
        (".attention.norm_k.weight", ".attention.k_norm.weight"),
    ];

    let exact: Option<&'static str> = match name {
        "all_x_embedder.2-1.weight" => Some("x_embedder.weight"),
        "all_x_embedder.2-1.bias" => Some("x_embedder.bias"),
        "all_final_layer.2-1.linear.weight" => Some("final_layer.linear.weight"),
        "all_final_layer.2-1.linear.bias" => Some("final_layer.linear.bias"),
        "all_final_layer.2-1.adaLN_modulation.1.weight" => {
            Some("final_layer.adaLN_modulation.1.weight")
        }
        "all_final_layer.2-1.adaLN_modulation.1.bias" => {
            Some("final_layer.adaLN_modulation.1.bias")
        }
        _ => None,
    };
    if let Some(stored) = exact {
        return Some(Cow::Borrowed(stored));
    }

    SUFFIX_RENAMES.iter().find_map(|(requested, stored)| {
        name.strip_suffix(*requested)
            .map(|prefix| Cow::Owned(format!("{prefix}{stored}")))
    })
}
/// Lower shift bound passed to `calculate_shift` in `build_zimage_scheduler`.
const BASE_SHIFT: f64 = 0.5;
/// Upper shift bound passed to `calculate_shift` in `build_zimage_scheduler`.
const MAX_SHIFT: f64 = 1.15;
340
341fn build_zimage_scheduler(
342 num_steps: usize,
343 image_seq_len: usize,
344 strength: Option<f64>,
345) -> (FlowMatchEulerDiscreteScheduler, usize) {
346 let mut scheduler = FlowMatchEulerDiscreteScheduler::new(SchedulerConfig::z_image_turbo());
347 let mu = calculate_shift(
348 image_seq_len,
349 BASE_IMAGE_SEQ_LEN,
350 MAX_IMAGE_SEQ_LEN,
351 BASE_SHIFT,
352 MAX_SHIFT,
353 );
354 let sigmas: Vec<f64> = (0..=num_steps)
355 .map(|v| v as f64 / num_steps as f64)
356 .rev()
357 .map(|t| {
358 if !(0.0..1.0).contains(&t) {
359 t
360 } else {
361 let e_mu = mu.exp();
362 e_mu / (e_mu + (1.0 / t - 1.0))
363 }
364 })
365 .collect();
366 scheduler.timesteps = sigmas[..sigmas.len().saturating_sub(1)]
367 .iter()
368 .map(|sigma| sigma * scheduler.config.num_train_timesteps as f64)
369 .collect();
370 scheduler.sigmas = sigmas;
371 let start_index = strength
372 .map(|strength| crate::img2img::img2img_start_index(num_steps, strength))
373 .unwrap_or(0);
374 if start_index > 0 {
375 scheduler.timesteps = scheduler.timesteps[start_index..].to_vec();
376 scheduler.sigmas = scheduler.sigmas[start_index..].to_vec();
377 }
378 scheduler.reset();
379 (scheduler, start_index)
380}
381
382fn load_zimage_vae(
383 path: &std::path::Path,
384 dtype: DType,
385 device: &Device,
386 progress: &ProgressReporter,
387 cached_tensors: Option<Arc<HashMap<String, Tensor>>>,
388) -> Result<AutoEncoderKL> {
389 use candle_core::safetensors::MmapedSafetensors;
390
391 let bytes_total = std::fs::metadata(path).map(|m| m.len()).unwrap_or(0);
392 progress.weight_load("VAE", 0, bytes_total);
393 let backend = if let Some(tensors) = cached_tensors {
394 ZImageVaeSafetensorsBackend::from_cpu_tensors(tensors)
395 } else {
396 let st = unsafe { MmapedSafetensors::multi(&[path])? };
397 ZImageVaeSafetensorsBackend::new(st)
398 };
399 let vae_vb = candle_nn::VarBuilder::from_backend(Box::new(backend), dtype, device.clone());
400 progress.weight_load("VAE", bytes_total, bytes_total);
401 AutoEncoderKL::new(&VaeConfig::z_image(), vae_vb).map_err(Into::into)
402}
403
404fn zimage_qwen3_preference<'a>(
405 configured: Option<&'a str>,
406 text_encoder_paths: &[std::path::PathBuf],
407) -> Option<&'a str> {
408 if configured.is_none() && zimage_has_recipe_text_encoder(text_encoder_paths) {
409 Some("bf16")
410 } else {
411 configured
412 }
413}
414
/// Returns `true` when any of `text_encoder_paths` has a path component that
/// is exactly `civitai`.
fn zimage_has_recipe_text_encoder(text_encoder_paths: &[std::path::PathBuf]) -> bool {
    text_encoder_paths
        .iter()
        .flat_map(|path| path.components())
        .any(|component| component.as_os_str() == "civitai")
}
421
422fn model_timestep(scheduler: &FlowMatchEulerDiscreteScheduler) -> f64 {
423 1.0 - scheduler.current_sigma()
424}
425
/// Whether the `MOLD_ZIMAGE_DEBUG` environment variable is set (to any value,
/// including an empty one).
fn zimage_debug_enabled() -> bool {
    matches!(std::env::var_os("MOLD_ZIMAGE_DEBUG"), Some(_))
}
429
430fn tensor_stats_summary(name: &str, tensor: &Tensor) -> Result<String> {
431 let flat = tensor.to_dtype(DType::F32)?.flatten_all()?;
432 let mean = flat.mean_all()?.to_scalar::<f32>()?;
433 let min = flat.min(0)?.to_scalar::<f32>()?;
434 let max = flat.max(0)?.to_scalar::<f32>()?;
435 let rms = flat.sqr()?.mean_all()?.to_scalar::<f32>()?.sqrt();
436 Ok(format!(
437 "{name}: mean={mean:.5} min={min:.5} max={max:.5} rms={rms:.5}"
438 ))
439}
440
/// Components of a loaded Z-Image model, as assembled by `ZImageEngine::load`.
struct LoadedZImage {
    /// Denoising transformer. `reload_transformer` sets this back to `Some`;
    /// NOTE(review): presumably it is taken/dropped elsewhere to free memory — confirm.
    transformer: Option<ZImageTransformer>,
    /// Qwen3 text encoder used for prompt conditioning.
    text_encoder: encoders::qwen3::Qwen3Encoder,
    /// VAE built by `load_vae`.
    vae: AutoEncoderKL,
    /// Transformer configuration the transformer was built with.
    transformer_cfg: Config,
    /// Device the transformer was loaded on.
    device: Device,
    /// Device the VAE was loaded on (may differ from `device`).
    vae_device: Device,
    /// Dtype derived from `device` via `gpu_dtype` and used for the transformer.
    dtype: DType,
    /// Dtype the VAE weights were loaded with.
    vae_dtype: DType,
    /// Whether the transformer file has a `.gguf` extension.
    is_gguf: bool,
    /// Path of the VAE weights file.
    vae_path: std::path::PathBuf,
}
463
/// Inference engine for Z-Image models.
pub struct ZImageEngine {
    /// Shared engine state (model name, paths, load strategy, GPU ordinal,
    /// progress reporter and the loaded components).
    base: EngineBase<LoadedZImage>,
    /// Explicitly configured Qwen3 variant, if any; see `zimage_qwen3_preference`.
    qwen3_variant: Option<String>,
    /// Whether transformer offloading was requested; see `zimage_offload_decision`.
    offload: bool,
    /// LRU cache of prompt conditioning tensors keyed by `prompt_text_key`.
    prompt_cache: Mutex<LruCache<String, CachedTensor>>,
    /// Optional per-component device placement consulted through `effective_device_ref`.
    pending_placement: Option<mold_core::types::DevicePlacement>,
    /// LoRA weights applied when the transformer is (re)loaded.
    pending_loras: Vec<LoraWeight>,
    /// Optional pool used to load the tokenizer and VAE tensors through a shared instance.
    shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
}
485
486pub(crate) fn effective_zimage_loras(req: &GenerateRequest) -> Vec<LoraWeight> {
490 const ZERO_SCALE_EPS: f64 = 1e-8;
492
493 let raw: Vec<LoraWeight> = if let Some(plural) = &req.loras {
494 if !plural.is_empty() {
495 plural.clone()
496 } else {
497 req.lora.iter().cloned().collect()
498 }
499 } else {
500 req.lora.iter().cloned().collect()
501 };
502 raw.into_iter()
503 .filter(|w| {
504 let keep = w.scale.abs() > ZERO_SCALE_EPS;
505 if !keep {
506 tracing::debug!(
507 path = w.path.as_str(),
508 scale = w.scale,
509 "dropping zero-scale Z-Image LoRA"
510 );
511 }
512 keep
513 })
514 .collect()
515}
516
/// Outcome of evaluating a block-level offload request.
#[derive(Debug, PartialEq, Eq)]
enum ZImageOffloadDecision {
    /// Offload was not requested.
    Disabled,
    /// Offload was requested and can be used.
    Selected,
    /// Offload was requested but cannot be used; carries the reason.
    Unsupported(&'static str),
}

/// Decides whether block-level offload applies.
///
/// Without `forced_offload` the answer is always `Disabled`. With it, GGUF
/// transformers are rejected first, then LoRA usage; only a non-GGUF
/// transformer without LoRA yields `Selected`.
fn zimage_offload_decision(
    forced_offload: bool,
    is_gguf: bool,
    has_lora: bool,
) -> ZImageOffloadDecision {
    match (forced_offload, is_gguf, has_lora) {
        (false, _, _) => ZImageOffloadDecision::Disabled,
        (true, true, _) => ZImageOffloadDecision::Unsupported(
            "Z-Image block-level offload is only planned for BF16/FP transformers; \
             GGUF variants already use quantized/dense GGUF-specific paths",
        ),
        (true, false, true) => ZImageOffloadDecision::Unsupported(
            "Z-Image block-level offload with LoRA is not wired yet; \
             LoRA merge/bypass semantics need a dedicated offload design",
        ),
        (true, false, false) => ZImageOffloadDecision::Selected,
    }
}
546
547impl ZImageEngine {
548 pub fn new(
549 model_name: String,
550 paths: ModelPaths,
551 qwen3_variant: Option<String>,
552 load_strategy: LoadStrategy,
553 gpu_ordinal: usize,
554 offload: bool,
555 shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
556 ) -> Self {
557 Self {
558 base: EngineBase::new(model_name, paths, load_strategy, gpu_ordinal),
559 qwen3_variant,
560 offload,
561 prompt_cache: Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY)),
562 pending_placement: None,
563 pending_loras: Vec::new(),
564 shared_pool,
565 }
566 }
567
568 fn load_text_tokenizer(&self, tokenizer_path: &Path) -> Result<Arc<Tokenizer>> {
569 if let Some(shared_pool) = &self.shared_pool {
570 return shared_pool.lock().unwrap().load_tokenizer(tokenizer_path);
571 }
572 Tokenizer::from_file(tokenizer_path)
573 .map(Arc::new)
574 .map_err(|e| anyhow::anyhow!("failed to load Qwen3 tokenizer: {e}"))
575 }
576
577 fn encode_prompt_cached(
578 progress: &ProgressReporter,
579 prompt_cache: &Mutex<LruCache<String, CachedTensor>>,
580 encoder: &mut encoders::qwen3::Qwen3Encoder,
581 prompt: &str,
582 device: &Device,
583 dtype: DType,
584 ) -> Result<(Tensor, Tensor)> {
585 let cache_key = prompt_text_key(prompt);
586 let (cap_feats, cache_hit) =
587 get_or_insert_cached_tensor(prompt_cache, cache_key, device, dtype, || {
588 progress.stage_start("Encoding prompt (Qwen3)");
589 let encode_start = Instant::now();
590 let (cap_feats, _token_count) = encoder.encode(prompt, device, dtype)?;
591 progress.stage_done("Encoding prompt (Qwen3)", encode_start.elapsed());
592 Ok(cap_feats)
593 })?;
594 if cache_hit {
595 progress.cache_hit("prompt conditioning");
596 }
597 let token_count = cap_feats.dim(1)?;
598 let cap_mask = Tensor::ones((1, token_count), DType::U8, device)?;
599 Ok((cap_feats, cap_mask))
600 }
601
602 fn transformer_paths(&self) -> Vec<std::path::PathBuf> {
605 if !self.base.paths.transformer_shards.is_empty() {
606 self.base.paths.transformer_shards.clone()
607 } else {
608 vec![self.base.paths.transformer.clone()]
609 }
610 }
611
612 fn detect_is_gguf(&self) -> bool {
614 self.base
615 .paths
616 .transformer
617 .extension()
618 .and_then(|e| e.to_str())
619 .map(|e| e.eq_ignore_ascii_case("gguf"))
620 .unwrap_or(false)
621 }
622
623 fn validate_paths(&self) -> Result<std::path::PathBuf> {
625 let text_tokenizer_path =
626 self.base.paths.text_tokenizer.as_ref().ok_or_else(|| {
627 anyhow::anyhow!("text tokenizer path required for Z-Image models")
628 })?;
629 if !text_tokenizer_path.exists() {
630 bail!(
631 "text tokenizer file not found: {}",
632 text_tokenizer_path.display()
633 );
634 }
635
636 let xformer_paths = self.transformer_paths();
637 for path in &xformer_paths {
638 if !path.exists() {
639 bail!("transformer file not found: {}", path.display());
640 }
641 }
642 if !self.base.paths.vae.exists() {
643 bail!("VAE file not found: {}", self.base.paths.vae.display());
644 }
645
646 Ok(text_tokenizer_path.clone())
647 }
648
/// Builds the Z-Image transformer on `device` with `dtype` from the
/// configured weight files.
///
/// The path taken depends on the transformer file extension, on
/// `pending_loras` and on `offload`:
/// - GGUF with LoRAs: `QuantizedZImageTransformer2DModel` over a LoRA-aware
///   GGUF var builder (`ZImageTransformer::Quantized`).
/// - GGUF without LoRAs: `load_gguf_dense_transformer` (`ZImageTransformer::Dense`).
/// - safetensors with LoRAs: `MoldZImageTransformer2DModel` over a
///   LoRA-wrapped `ZImageSafetensorsBackend` (`ZImageTransformer::Dense`).
/// - safetensors with `offload`: `OffloadedZImageTransformer` built from one
///   var builder targeting `device` and one targeting the CPU
///   (`ZImageTransformer::Offloaded`).
/// - otherwise: `MoldZImageTransformer2DModel` over a plain
///   `ZImageSafetensorsBackend` (`ZImageTransformer::Dense`).
///
/// # Errors
/// Fails when LoRA adapters cannot be loaded, when the weight files cannot be
/// opened/memory-mapped, or when model construction fails.
fn load_transformer(
    &self,
    device: &Device,
    dtype: DType,
    cfg: &Config,
) -> Result<ZImageTransformer> {
    let is_gguf = self.detect_is_gguf();
    let xformer_paths = self.transformer_paths();
    let has_lora = !self.pending_loras.is_empty();

    if is_gguf {
        if has_lora {
            // Adapters and `pending_loras` are zipped pairwise, so each spec
            // carries the scale and path hash of the request it was loaded for.
            let adapters =
                super::lora::load_lora_adapters(&self.pending_loras, &self.base.progress)?;
            let specs: Vec<super::lora::ZImageLoraSpec<'_>> = adapters
                .iter()
                .zip(self.pending_loras.iter())
                .map(|(adapter, w)| super::lora::ZImageLoraSpec {
                    adapter: adapter.as_ref(),
                    scale: w.scale,
                    path_hash: super::lora::lora_path_hash(&w.path),
                })
                .collect();
            let vb = super::lora::gguf_lora_var_builder(
                &self.base.paths.transformer,
                &specs,
                device,
                &self.base.progress,
            )?;
            return Ok(ZImageTransformer::Quantized(Box::new(
                super::quantized_transformer::QuantizedZImageTransformer2DModel::new(
                    cfg, dtype, vb,
                )?,
            )));
        }
        // GGUF without LoRA: only the single transformer path is read here.
        let qvb =
            quantized_var_builder::VarBuilder::from_gguf(&self.base.paths.transformer, device)?;
        Ok(ZImageTransformer::Dense(Box::new(
            load_gguf_dense_transformer(cfg, dtype, qvb)?,
        )))
    } else if has_lora {
        use candle_core::safetensors::MmapedSafetensors;
        let path_refs: Vec<&std::path::Path> =
            xformer_paths.iter().map(|p| p.as_path()).collect();
        // NOTE(review): memory-mapping assumes the weight files are not
        // modified while mapped — confirm this holds for the model cache.
        let st = unsafe { MmapedSafetensors::multi(&path_refs)? };
        let inner: Box<dyn candle_nn::var_builder::SimpleBackend> =
            Box::new(ZImageSafetensorsBackend::new(st));
        let adapters =
            super::lora::load_lora_adapters(&self.pending_loras, &self.base.progress)?;
        let specs: Vec<super::lora::ZImageLoraSpec<'_>> = adapters
            .iter()
            .zip(self.pending_loras.iter())
            .map(|(adapter, w)| super::lora::ZImageLoraSpec {
                adapter: adapter.as_ref(),
                scale: w.scale,
                path_hash: super::lora::lora_path_hash(&w.path),
            })
            .collect();
        // The LoRA wrapper sits in front of the name-translating backend.
        let wrapped =
            super::lora::wrap_backend_with_lora(inner, &specs, &self.base.progress, None)?;
        let vb = candle_nn::VarBuilder::from_backend(wrapped, dtype, device.clone());
        Ok(ZImageTransformer::Dense(Box::new(
            MoldZImageTransformer2DModel::new(cfg, vb)?,
        )))
    } else if self.offload {
        use candle_core::safetensors::MmapedSafetensors;
        let path_refs: Vec<&std::path::Path> =
            xformer_paths.iter().map(|p| p.as_path()).collect();
        // Files whose metadata cannot be read contribute 0 bytes to the total.
        let bytes_total: u64 = xformer_paths
            .iter()
            .map(|p| std::fs::metadata(p).map(|m| m.len()).unwrap_or(0))
            .sum();
        self.base
            .progress
            .weight_load("Z-Image transformer (offload stems)", 0, bytes_total);
        // The same files are mapped twice: one backend feeds the var builder
        // bound to `device`, the other the var builder bound to the CPU.
        let gpu_st = unsafe { MmapedSafetensors::multi(&path_refs)? };
        let cpu_st = unsafe { MmapedSafetensors::multi(&path_refs)? };
        let gpu_vb = candle_nn::VarBuilder::from_backend(
            Box::new(ZImageSafetensorsBackend::new(gpu_st)),
            dtype,
            device.clone(),
        );
        let cpu_vb = candle_nn::VarBuilder::from_backend(
            Box::new(ZImageSafetensorsBackend::new(cpu_st)),
            dtype,
            Device::Cpu,
        );
        self.base.progress.weight_load(
            "Z-Image transformer (offload stems)",
            bytes_total,
            bytes_total,
        );
        Ok(ZImageTransformer::Offloaded(Box::new(
            super::offload::OffloadedZImageTransformer::new(cfg, gpu_vb, cpu_vb)?,
        )))
    } else {
        use candle_core::safetensors::MmapedSafetensors;
        let path_refs: Vec<&std::path::Path> =
            xformer_paths.iter().map(|p| p.as_path()).collect();
        // Files whose metadata cannot be read contribute 0 bytes to the total.
        let bytes_total = xformer_paths
            .iter()
            .map(|p| std::fs::metadata(p).map(|m| m.len()).unwrap_or(0))
            .sum();
        self.base
            .progress
            .weight_load("Z-Image transformer", 0, bytes_total);
        let st = unsafe { MmapedSafetensors::multi(&path_refs)? };
        let xformer_vb = candle_nn::VarBuilder::from_backend(
            Box::new(ZImageSafetensorsBackend::new(st)),
            dtype,
            device.clone(),
        );
        // Completion is reported once the var builder exists, before the model
        // is constructed from it.
        self.base
            .progress
            .weight_load("Z-Image transformer", bytes_total, bytes_total);
        Ok(ZImageTransformer::Dense(Box::new(
            MoldZImageTransformer2DModel::new(cfg, xformer_vb)?,
        )))
    }
}
787
788 fn load_vae(&self, device: &Device, dtype: DType) -> Result<AutoEncoderKL> {
790 let cached_tensors = self.load_vae_cpu_tensors()?;
791 load_zimage_vae(
792 self.base.paths.vae.as_path(),
793 dtype,
794 device,
795 &self.base.progress,
796 cached_tensors,
797 )
798 }
799
800 fn load_vae_cpu_tensors(&self) -> Result<Option<Arc<HashMap<String, Tensor>>>> {
801 let Some(shared_pool) = &self.shared_pool else {
802 return Ok(None);
803 };
804 shared_pool
805 .lock()
806 .unwrap()
807 .load_safetensors_cpu_tensors(std::slice::from_ref(&self.base.paths.vae))
808 }
809
/// Eagerly loads transformer, VAE and Qwen3 text encoder and stores them in
/// `self.base.loaded`.
///
/// Does nothing when components are already loaded or when the load strategy
/// is `Sequential`. Otherwise the components are loaded in this order:
/// transformer, VAE, text encoder. Device and dtype for the VAE and the text
/// encoder are chosen from the free VRAM measured after the transformer has
/// been loaded, unless `pending_placement` supplies an explicit device.
///
/// # Errors
/// Fails when required files are missing, a device cannot be resolved, or any
/// component fails to load.
pub fn load(&mut self) -> Result<()> {
    if self.base.loaded.is_some() {
        return Ok(());
    }

    // Nothing is loaded up front under the sequential strategy.
    // NOTE(review): presumably components are then loaded on demand by the
    // sequential generate path — confirm.
    if self.base.load_strategy == LoadStrategy::Sequential {
        return Ok(());
    }

    tracing::info!(model = %self.base.model_name, "loading Z-Image model components...");

    let is_gguf = self.detect_is_gguf();
    let text_tokenizer_path = self.validate_paths()?;

    // Transformer device: explicit placement if any, else a device created
    // for the configured GPU ordinal.
    let transformer_ref = effective_device_ref(
        self.pending_placement.as_ref(),
        |adv| Some(adv.transformer),
        false,
    );
    let device = crate::device::resolve_device(Some(transformer_ref), || {
        crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)
    })?;
    let dtype = crate::engine::gpu_dtype(&device);
    let transformer_cfg = Config::z_image_turbo();

    let xformer_label = if is_gguf {
        "Loading Z-Image transformer (GPU, GGUF -> dense)".to_string()
    } else {
        let xformer_paths = self.transformer_paths();
        format!(
            "Loading Z-Image transformer ({} shards)",
            xformer_paths.len()
        )
    };
    self.base.progress.stage_start(&xformer_label);
    let xformer_start = Instant::now();

    let transformer = self.load_transformer(&device, dtype, &transformer_cfg)?;

    self.base
        .progress
        .stage_done(&xformer_label, xformer_start.elapsed());
    tracing::info!(quantized = is_gguf, "Z-Image transformer loaded");

    // VRAM is probed after the transformer is resident. `free_raw` is only
    // reported; `free` (the usable amount) drives the placement decisions
    // below. A failed probe is treated as 0 bytes.
    let free_raw = free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
    let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
    let is_cuda = device.is_cuda();
    let is_metal = device.is_metal();
    if free_raw > 0 {
        self.base.progress.info(&format!(
            "Free VRAM after transformer: {}",
            fmt_gb(free_raw)
        ));
        tracing::info!(
            free_vram = free_raw,
            free_vram_usable = free,
            "free VRAM after loading transformer"
        );
    }

    // Automatic VAE placement; an explicit placement in `pending_placement`
    // takes precedence inside `resolve_device`.
    let vae_on_gpu = should_use_gpu(is_cuda, is_metal, free, VAE_WEIGHT_LOAD_VRAM_THRESHOLD);
    let vae_ref =
        effective_device_ref(self.pending_placement.as_ref(), |adv| Some(adv.vae), false);
    let vae_device = crate::device::resolve_device(Some(vae_ref), || {
        Ok(if vae_on_gpu {
            device.clone()
        } else {
            Device::Cpu
        })
    })?;
    // Re-derived from the resolved device so labels and dtype reflect the
    // final placement rather than the automatic suggestion.
    let vae_on_gpu = !vae_device.is_cpu();
    let vae_dtype = if vae_on_gpu {
        crate::device::resolve_vae_dtype(dtype)
    } else {
        DType::F32
    };
    let vae_device_label = if vae_on_gpu { "GPU" } else { "CPU" };

    if !vae_on_gpu && (is_cuda || is_metal) {
        self.base.progress.info(&format!(
            "VAE on CPU ({} free < {} threshold for VAE weight load)",
            fmt_gb(free),
            fmt_gb(VAE_WEIGHT_LOAD_VRAM_THRESHOLD),
        ));
    }

    let vae_label = format!("Loading VAE ({})", vae_device_label);
    self.base.progress.stage_start(&vae_label);
    let vae_start = Instant::now();
    let vae = self.load_vae(&vae_device, vae_dtype)?;
    self.base
        .progress
        .stage_done(&vae_label, vae_start.elapsed());
    tracing::info!(device = vae_device_label, "Z-Image VAE loaded");

    // Pick the Qwen3 variant (files, GGUF or not, automatic device) using the
    // free-VRAM figure measured after the transformer load.
    self.base.progress.stage_start("Selecting Qwen3 encoder");
    let qwen3_resolve_start = Instant::now();
    let qwen3_preference = zimage_qwen3_preference(
        self.qwen3_variant.as_deref(),
        &self.base.paths.text_encoder_files,
    );
    let (resolved_paths, is_qwen3_gguf, te_on_gpu, _te_auto_device_label) = {
        let bf16_paths = self.base.paths.text_encoder_files.clone();
        // BF16 is only considered available when every configured file exists.
        let have_bf16 = !bf16_paths.is_empty() && bf16_paths.iter().all(|p| p.exists());
        crate::encoders::variant_resolution::resolve_qwen3_variant(
            &self.base.progress,
            qwen3_preference,
            &device,
            free,
            &bf16_paths,
            have_bf16,
            false,
            crate::encoders::variant_resolution::Qwen3Size::B4,
        )?
    };
    self.base
        .progress
        .stage_done("Selecting Qwen3 encoder", qwen3_resolve_start.elapsed());

    // Text encoder device: explicit placement if any, else the automatic choice.
    let qwen3_ref = effective_device_ref(self.pending_placement.as_ref(), |adv| adv.qwen, true);
    let auto_te_device = if te_on_gpu {
        device.clone()
    } else {
        Device::Cpu
    };
    let te_device =
        crate::device::resolve_device(Some(qwen3_ref), || Ok(auto_te_device.clone()))?;
    let te_on_gpu = !te_device.is_cpu();
    let te_device_label = if te_on_gpu { "GPU" } else { "CPU" };
    // The encoder runs in F32 whenever it ends up on the CPU.
    let te_dtype = if te_on_gpu { dtype } else { DType::F32 };

    let bf16_cfg = encoders::qwen3_bf16::Qwen3BF16Config::qwen3_4b();
    let te_label = if is_qwen3_gguf {
        format!("Loading Qwen3 text encoder (GGUF, {})", te_device_label)
    } else {
        format!(
            "Loading Qwen3 text encoder ({} shards, {})",
            resolved_paths.len(),
            te_device_label,
        )
    };
    self.base.progress.stage_start(&te_label);
    let te_start = Instant::now();
    let text_tokenizer = self.load_text_tokenizer(&text_tokenizer_path)?;

    let text_encoder = if is_qwen3_gguf {
        // The GGUF variant is loaded from the first resolved path only.
        encoders::qwen3::Qwen3Encoder::load_gguf_with_tokenizer(
            &resolved_paths[0],
            &text_tokenizer_path,
            Some(text_tokenizer),
            &te_device,
            &bf16_cfg,
        )?
    } else {
        encoders::qwen3::Qwen3Encoder::load_bf16_with_tokenizer(
            &resolved_paths,
            &text_tokenizer_path,
            Some(text_tokenizer),
            &te_device,
            te_dtype,
            &bf16_cfg,
            &self.base.progress,
        )?
    };

    self.base.progress.stage_done(&te_label, te_start.elapsed());
    tracing::info!(device = %te_device_label, quantized = is_qwen3_gguf, "Qwen3 text encoder loaded");

    self.base.loaded = Some(LoadedZImage {
        transformer: Some(transformer),
        text_encoder,
        vae,
        transformer_cfg,
        device,
        vae_device,
        dtype,
        vae_dtype,
        is_gguf,
        vae_path: self.base.paths.vae.clone(),
    });

    tracing::info!(model = %self.base.model_name, "all Z-Image components loaded successfully");
    Ok(())
}
1012
1013 fn reload_transformer(&self, loaded: &mut LoadedZImage) -> Result<()> {
1015 let transformer =
1016 self.load_transformer(&loaded.device, loaded.dtype, &loaded.transformer_cfg)?;
1017 loaded.transformer = Some(transformer);
1018 Ok(())
1019 }
1020
1021 fn uses_sequential_generate_path(&self) -> bool {
1022 self.base.load_strategy == LoadStrategy::Sequential
1023 || self.offload
1024 || !self.pending_loras.is_empty()
1025 }
1026
1027 fn generate_sequential(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1036 let text_tokenizer_path = self.validate_paths()?;
1037 let is_gguf = self.detect_is_gguf();
1038 let transformer_cfg = Config::z_image_turbo();
1039
1040 match zimage_offload_decision(self.offload, is_gguf, !self.pending_loras.is_empty()) {
1041 ZImageOffloadDecision::Disabled => {}
1042 ZImageOffloadDecision::Unsupported(reason) => bail!("{reason}"),
1043 ZImageOffloadDecision::Selected => {}
1044 }
1045
1046 if let Some(warning) = check_memory_budget(&self.base.paths, LoadStrategy::Sequential) {
1048 self.base.progress.info(&warning);
1049 }
1050
1051 let transformer_ref = effective_device_ref(
1052 self.pending_placement.as_ref(),
1053 |adv| Some(adv.transformer),
1054 false,
1055 );
1056 let device = crate::device::resolve_device(Some(transformer_ref), || {
1057 crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)
1058 })?;
1059 let dtype = crate::engine::gpu_dtype(&device);
1060
1061 let start = Instant::now();
1062 let seed = req.seed.unwrap_or_else(rand_seed);
1063
1064 let width = req.width as usize;
1065 let height = req.height as usize;
1066
1067 tracing::info!(
1068 prompt = %req.prompt,
1069 seed, width, height,
1070 steps = req.steps,
1071 "starting sequential Z-Image generation"
1072 );
1073
1074 self.base
1075 .progress
1076 .info("Using sequential loading (load-use-drop) to minimize peak memory");
1077
1078 let cache_key = prompt_text_key(&req.prompt);
1080 let (cap_feats, cap_mask) = if let Some(cap_feats) =
1081 restore_cached_tensor(&self.prompt_cache, &cache_key, &device, dtype)?
1082 {
1083 self.base.progress.cache_hit("prompt conditioning");
1084 let token_count = cap_feats.dim(1)?;
1085 let cap_mask = Tensor::ones((1, token_count), DType::U8, &device)?;
1086 (cap_feats, cap_mask)
1087 } else {
1088 let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
1090 self.base.progress.stage_start("Selecting Qwen3 encoder");
1091 let qwen3_resolve_start = Instant::now();
1092 let qwen3_preference = zimage_qwen3_preference(
1093 self.qwen3_variant.as_deref(),
1094 &self.base.paths.text_encoder_files,
1095 );
1096 let (resolved_paths, is_qwen3_gguf, te_on_gpu, _te_auto_device_label) = {
1097 let bf16_paths = self.base.paths.text_encoder_files.clone();
1098 let have_bf16 = !bf16_paths.is_empty() && bf16_paths.iter().all(|p| p.exists());
1099 crate::encoders::variant_resolution::resolve_qwen3_variant(
1100 &self.base.progress,
1101 qwen3_preference,
1102 &device,
1103 free,
1104 &bf16_paths,
1105 have_bf16,
1106 false,
1107 crate::encoders::variant_resolution::Qwen3Size::B4,
1108 )?
1109 };
1110 self.base
1111 .progress
1112 .stage_done("Selecting Qwen3 encoder", qwen3_resolve_start.elapsed());
1113
1114 let qwen3_ref =
1115 effective_device_ref(self.pending_placement.as_ref(), |adv| adv.qwen, true);
1116 let auto_te_device = if te_on_gpu {
1117 device.clone()
1118 } else {
1119 Device::Cpu
1120 };
1121 let te_device =
1122 crate::device::resolve_device(Some(qwen3_ref), || Ok(auto_te_device.clone()))?;
1123 let te_on_gpu = !te_device.is_cpu();
1124 let te_device_label = if te_on_gpu { "GPU" } else { "CPU" };
1125 let te_dtype = if te_on_gpu { dtype } else { DType::F32 };
1126
1127 let bf16_cfg = encoders::qwen3_bf16::Qwen3BF16Config::qwen3_4b();
1128 let te_label = if is_qwen3_gguf {
1129 format!("Loading Qwen3 text encoder (GGUF, {})", te_device_label)
1130 } else {
1131 format!(
1132 "Loading Qwen3 text encoder ({} shards, {})",
1133 resolved_paths.len(),
1134 te_device_label,
1135 )
1136 };
1137 let te_size: u64 = resolved_paths
1138 .iter()
1139 .filter_map(|p| std::fs::metadata(p).ok())
1140 .map(|m| m.len())
1141 .sum();
1142 let te_activation_budget = crate::device::activation_bytes(
1143 req.width,
1144 req.height,
1145 1,
1146 crate::device::dtype_bytes(te_dtype),
1147 crate::device::ActivationFamily::SmallTransformer,
1148 );
1149 preflight_memory_check("Qwen3 text encoder", te_size, te_activation_budget)?;
1150
1151 if let Some(status) = memory_status_string() {
1152 self.base.progress.info(&status);
1153 }
1154
1155 self.base.progress.stage_start(&te_label);
1156 let te_start = Instant::now();
1157 let text_tokenizer = self.load_text_tokenizer(&text_tokenizer_path)?;
1158
1159 let mut text_encoder = if is_qwen3_gguf {
1160 encoders::qwen3::Qwen3Encoder::load_gguf_with_tokenizer(
1161 &resolved_paths[0],
1162 &text_tokenizer_path,
1163 Some(text_tokenizer),
1164 &te_device,
1165 &bf16_cfg,
1166 )?
1167 } else {
1168 encoders::qwen3::Qwen3Encoder::load_bf16_with_tokenizer(
1169 &resolved_paths,
1170 &text_tokenizer_path,
1171 Some(text_tokenizer),
1172 &te_device,
1173 te_dtype,
1174 &bf16_cfg,
1175 &self.base.progress,
1176 )?
1177 };
1178 self.base.progress.stage_done(&te_label, te_start.elapsed());
1179
1180 let (cap_feats, cap_mask) = Self::encode_prompt_cached(
1181 &self.base.progress,
1182 &self.prompt_cache,
1183 &mut text_encoder,
1184 &req.prompt,
1185 &device,
1186 dtype,
1187 )?;
1188
1189 drop(text_encoder);
1190 self.base.progress.info("Freed Qwen3 text encoder");
1191 tracing::info!("Qwen3 text encoder dropped (sequential mode)");
1192
1193 (cap_feats, cap_mask)
1194 };
1195
1196 let vae_align = 16;
1200 let latent_h = 2 * (height / vae_align);
1201 let latent_w = 2 * (width / vae_align);
1202
1203 let patch_size = transformer_cfg.all_patch_size[0];
1204 let image_seq_len = (latent_h / patch_size) * (latent_w / patch_size);
1205 let (mut scheduler, start_index) = build_zimage_scheduler(
1206 req.steps as usize,
1207 image_seq_len,
1208 req.source_image.as_ref().map(|_| req.strength),
1209 );
1210
1211 if req.source_image.is_some() {
1212 tracing::info!(
1213 strength = req.strength,
1214 start_index,
1215 start_sigma = scheduler.sigmas[0],
1216 remaining_sigmas = scheduler.sigmas.len(),
1217 remaining_steps = scheduler.sigmas.len().saturating_sub(1),
1218 "img2img: truncated schedule from strength"
1219 );
1220 }
1221
1222 let (mut latents, inpaint_ctx) = if let Some(ref source_bytes) = req.source_image {
1224 let start_sigma = scheduler.sigmas[0];
1225
1226 let encode_vae_device = if device.is_cuda() || device.is_metal() {
1228 device.clone()
1229 } else {
1230 Device::Cpu
1231 };
1232 let encode_vae_dtype = if encode_vae_device.is_cpu() {
1233 DType::F32
1234 } else {
1235 crate::device::resolve_vae_dtype(dtype)
1236 };
1237 let encode_label = if encode_vae_device.is_cpu() {
1238 "Loading VAE for source encoding (CPU)"
1239 } else {
1240 "Loading VAE for source encoding (GPU)"
1241 };
1242
1243 self.base.progress.stage_start(encode_label);
1244 let vae_enc_start = Instant::now();
1245 let encode_vae = self.load_vae(&encode_vae_device, encode_vae_dtype)?;
1246 self.base
1247 .progress
1248 .stage_done(encode_label, vae_enc_start.elapsed());
1249
1250 self.base
1251 .progress
1252 .stage_start("Encoding source image (VAE)");
1253 let encode_start = Instant::now();
1254 let source_tensor = img_utils::decode_source_image(
1255 source_bytes,
1256 req.width,
1257 req.height,
1258 img_utils::NormalizeRange::ZeroToOne,
1259 &encode_vae_device,
1260 encode_vae_dtype,
1261 )?;
1262 let encoded = encode_vae.encode(&source_tensor)?;
1263 self.base
1264 .progress
1265 .stage_done("Encoding source image (VAE)", encode_start.elapsed());
1266
1267 drop(encode_vae);
1269
1270 let encoded = encoded.to_dtype(dtype)?.to_device(&device)?;
1272 let prepared = crate::img2img::prepare_flow_match_img2img(
1273 &encoded,
1274 seed,
1275 &[1, 16, latent_h, latent_w],
1276 start_sigma,
1277 req.mask_image.as_deref(),
1278 latent_h,
1279 latent_w,
1280 &device,
1281 dtype,
1282 )?;
1283 (prepared.initial_latents.unsqueeze(2)?, prepared.inpaint_ctx)
1285 } else {
1286 let noise =
1288 crate::engine::seeded_randn(seed, &[1, 16, latent_h, latent_w], &device, dtype)?;
1289 (noise.unsqueeze(2)?, None)
1290 };
1291
1292 let xformer_paths = self.transformer_paths();
1294 let xformer_size: u64 = xformer_paths
1295 .iter()
1296 .filter_map(|p| std::fs::metadata(p).ok())
1297 .map(|m| m.len())
1298 .sum();
1299 let xformer_activation_budget = crate::device::activation_bytes(
1300 req.width,
1301 req.height,
1302 1,
1303 crate::device::dtype_bytes(dtype),
1304 crate::device::ActivationFamily::ZImageDit,
1305 );
1306 preflight_memory_check(
1307 "Z-Image transformer",
1308 xformer_size,
1309 xformer_activation_budget,
1310 )?;
1311
1312 if let Some(status) = memory_status_string() {
1313 self.base.progress.info(&status);
1314 }
1315
1316 let xformer_label = if is_gguf {
1317 "Loading Z-Image transformer (GPU, GGUF -> dense)".to_string()
1318 } else {
1319 format!(
1320 "Loading Z-Image transformer ({} shards)",
1321 xformer_paths.len()
1322 )
1323 };
1324 self.base.progress.stage_start(&xformer_label);
1325 let xformer_start = Instant::now();
1326 let transformer = self.load_transformer(&device, dtype, &transformer_cfg)?;
1327 self.base
1328 .progress
1329 .stage_done(&xformer_label, xformer_start.elapsed());
1330
1331 let num_steps = scheduler.sigmas.len().saturating_sub(1);
1332 let denoise_label = format!("Denoising ({} steps)", num_steps);
1333 self.base.progress.stage_start(&denoise_label);
1334 let denoise_start = Instant::now();
1335
1336 for step in 0..num_steps {
1337 let step_start = Instant::now();
1338 let t = model_timestep(&scheduler);
1339 let t_tensor = Tensor::from_vec(vec![t as f32], (1,), &device)?.to_dtype(dtype)?;
1340 if zimage_debug_enabled() {
1341 tracing::debug!(
1342 step = step + 1,
1343 total = num_steps,
1344 sigma = scheduler.current_sigma(),
1345 timestep = t,
1346 "{}",
1347 tensor_stats_summary("latents_in", &latents)?
1348 );
1349 }
1350 let noise_pred = transformer.forward(&latents, &t_tensor, &cap_feats, &cap_mask)?;
1351 if zimage_debug_enabled() {
1352 tracing::debug!(
1353 step = step + 1,
1354 total = num_steps,
1355 "{}",
1356 tensor_stats_summary("noise_pred_raw", &noise_pred)?
1357 );
1358 }
1359 let noise_pred = noise_pred.neg()?;
1360 let noise_pred_4d = noise_pred.squeeze(2)?;
1361 let latents_4d = latents.squeeze(2)?;
1362 let prev_latents = scheduler.step(&noise_pred_4d, &latents_4d)?;
1363 latents = prev_latents.unsqueeze(2)?;
1364 if zimage_debug_enabled() {
1365 tracing::debug!(
1366 step = step + 1,
1367 total = num_steps,
1368 sigma_next = scheduler.current_sigma(),
1369 "{}",
1370 tensor_stats_summary("latents_out", &latents)?
1371 );
1372 }
1373
1374 if let Some(ref ctx) = inpaint_ctx {
1376 let latents_4d = latents.squeeze(2)?;
1377 let blended = crate::img2img::apply_flow_match_inpaint(
1378 &latents_4d,
1379 ctx,
1380 scheduler.sigmas[step + 1],
1381 )?;
1382 latents = blended.unsqueeze(2)?;
1383 }
1384
1385 self.base.progress.emit(ProgressEvent::DenoiseStep {
1386 step: step + 1,
1387 total: num_steps,
1388 elapsed: step_start.elapsed(),
1389 });
1390 }
1391
1392 self.base
1393 .progress
1394 .stage_done(&denoise_label, denoise_start.elapsed());
1395
1396 drop(transformer);
1398 self.base.progress.info("Freed Z-Image transformer");
1399 drop(cap_feats);
1400 drop(cap_mask);
1401 device.synchronize()?;
1402 tracing::info!("Transformer dropped (sequential mode)");
1403
1404 if let Some(status) = memory_status_string() {
1406 self.base.progress.info(&status);
1407 }
1408 let free_for_vae = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
1411 let vae_on_gpu = should_use_gpu(
1412 device.is_cuda(),
1413 device.is_metal(),
1414 free_for_vae,
1415 VAE_DECODE_VRAM_THRESHOLD,
1416 );
1417 let vae_ref =
1418 effective_device_ref(self.pending_placement.as_ref(), |adv| Some(adv.vae), false);
1419 let vae_device = crate::device::resolve_device(Some(vae_ref), || {
1420 Ok(if vae_on_gpu {
1421 device.clone()
1422 } else {
1423 Device::Cpu
1424 })
1425 })?;
1426 let vae_on_gpu = !vae_device.is_cpu();
1427 let vae_dtype = if vae_on_gpu {
1430 crate::device::resolve_vae_dtype(dtype)
1431 } else {
1432 DType::F32
1433 };
1434 let vae_device_label = if vae_on_gpu { "GPU" } else { "CPU" };
1435
1436 let vae_label = format!("Loading VAE ({})", vae_device_label);
1437 self.base.progress.stage_start(&vae_label);
1438 let vae_start = Instant::now();
1439 let vae = self.load_vae(&vae_device, vae_dtype)?;
1440 self.base
1441 .progress
1442 .stage_done(&vae_label, vae_start.elapsed());
1443
1444 self.base.progress.stage_start("VAE decode");
1445 let vae_decode_start = Instant::now();
1446
1447 let latents = latents
1448 .squeeze(2)?
1449 .to_device(&vae_device)?
1450 .to_dtype(vae_dtype)?;
1451 let image = vae.decode(&latents)?;
1452 let image = postprocess_image(&image)?;
1453 let image = image.i(0)?;
1454
1455 self.base
1456 .progress
1457 .stage_done("VAE decode", vae_decode_start.elapsed());
1458
1459 let output_metadata = build_output_metadata(req, seed, None);
1461 let image_bytes = encode_image(
1462 &image,
1463 req.resolved_output_format(),
1464 req.width,
1465 req.height,
1466 output_metadata.as_ref(),
1467 )?;
1468
1469 let generation_time_ms = start.elapsed().as_millis() as u64;
1470 tracing::info!(
1471 generation_time_ms,
1472 seed,
1473 "sequential Z-Image generation complete"
1474 );
1475
1476 Ok(GenerateResponse {
1477 images: vec![ImageData {
1478 data: image_bytes,
1479 format: req.resolved_output_format(),
1480 width: req.width,
1481 height: req.height,
1482 index: 0,
1483 }],
1484 generation_time_ms,
1485 model: req.model.clone(),
1486 seed_used: seed,
1487 video: None,
1488 gpu: None,
1489 })
1490 }
1491}
1492
1493impl ZImageEngine {
1494 fn generate_inner(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1495 if req.scheduler.is_some() {
1496 tracing::warn!(
1497 "scheduler selection not supported for Z-Image (flow-matching), ignoring"
1498 );
1499 }
1500 if self.uses_sequential_generate_path() {
1505 self.base.unload();
1506 return self.generate_sequential(req);
1507 }
1508
1509 if self.base.loaded.is_none() {
1511 self.load()?;
1512 }
1513 if self.base.loaded.is_none() {
1514 bail!("model not loaded — call load() first");
1515 }
1516
1517 let progress = &self.base.progress;
1519
1520 let start = Instant::now();
1521
1522 let loaded_ref = self
1524 .base
1525 .loaded
1526 .as_ref()
1527 .ok_or_else(|| anyhow::anyhow!("model not loaded — call load() first"))?;
1528 let needs_reload = loaded_ref.transformer.is_none();
1529 if needs_reload {
1530 {
1531 let mut loaded_mut = self
1532 .base
1533 .loaded
1534 .take()
1535 .ok_or_else(|| anyhow::anyhow!("model not loaded — call load() first"))?;
1536 let xformer_label = if loaded_mut.is_gguf {
1537 "Reloading Z-Image transformer (GPU, GGUF -> dense)"
1538 } else {
1539 "Reloading Z-Image transformer (GPU, BF16)"
1540 };
1541 progress.stage_start(xformer_label);
1542 let reload_start = Instant::now();
1543 self.reload_transformer(&mut loaded_mut)?;
1544 progress.stage_done(xformer_label, reload_start.elapsed());
1545 self.base.loaded = Some(loaded_mut);
1546 }
1547 }
1548
1549 let loaded = self
1550 .base
1551 .loaded
1552 .as_mut()
1553 .ok_or_else(|| anyhow::anyhow!("model not loaded — call load() first"))?;
1554 let seed = req.seed.unwrap_or_else(rand_seed);
1555
1556 let width = req.width as usize;
1557 let height = req.height as usize;
1558
1559 tracing::info!(
1560 prompt = %req.prompt,
1561 seed, width, height,
1562 steps = req.steps,
1563 "starting Z-Image generation"
1564 );
1565
1566 let cache_key = prompt_text_key(&req.prompt);
1568 let (cap_feats, cap_mask) = if let Some(cap_feats) =
1569 restore_cached_tensor(&self.prompt_cache, &cache_key, &loaded.device, loaded.dtype)?
1570 {
1571 progress.cache_hit("prompt conditioning");
1572 let token_count = cap_feats.dim(1)?;
1573 let cap_mask = Tensor::ones((1, token_count), DType::U8, &loaded.device)?;
1574 (cap_feats, cap_mask)
1575 } else {
1576 if loaded.text_encoder.model.is_none() {
1580 let te_label = if loaded.text_encoder.is_parked() {
1581 "Unparking Qwen3 encoder (CPU→GPU)"
1582 } else if loaded.text_encoder.is_quantized {
1583 "Reloading Qwen3 encoder (GGUF)"
1584 } else {
1585 "Reloading Qwen3 encoder (BF16)"
1586 };
1587 progress.stage_start(te_label);
1588 let reload_start = Instant::now();
1589 if loaded.text_encoder.is_parked() {
1590 loaded.text_encoder.unpark_to_gpu(progress)?;
1591 } else {
1592 loaded.text_encoder.reload(progress)?;
1593 }
1594 progress.stage_done(te_label, reload_start.elapsed());
1595 }
1596
1597 let (cap_feats, cap_mask) = Self::encode_prompt_cached(
1598 progress,
1599 &self.prompt_cache,
1600 &mut loaded.text_encoder,
1601 &req.prompt,
1602 &loaded.device,
1603 loaded.dtype,
1604 )?;
1605 tracing::info!(token_count = cap_feats.dim(1)?, "text encoding complete");
1606
1607 if loaded.text_encoder.on_gpu || loaded.device.is_metal() {
1613 let park_mode = crate::device::keep_te_in_ram()
1614 && !loaded.device.is_metal()
1615 && !loaded.text_encoder.is_quantized;
1616 if park_mode {
1617 loaded.text_encoder.park_to_cpu()?;
1618 tracing::info!(
1619 on_gpu = loaded.text_encoder.on_gpu,
1620 "Qwen3 text encoder parked to CPU host RAM"
1621 );
1622 } else {
1623 loaded.text_encoder.drop_weights();
1624 tracing::info!(
1625 on_gpu = loaded.text_encoder.on_gpu,
1626 "Qwen3 text encoder dropped to free memory for denoising"
1627 );
1628 }
1629 }
1630
1631 (cap_feats, cap_mask)
1632 };
1633
1634 let vae_align = 16;
1636 let latent_h = 2 * (height / vae_align);
1637 let latent_w = 2 * (width / vae_align);
1638
1639 let patch_size = loaded.transformer_cfg.all_patch_size[0];
1641 let image_seq_len = (latent_h / patch_size) * (latent_w / patch_size);
1642 let (mut scheduler, start_index) = build_zimage_scheduler(
1643 req.steps as usize,
1644 image_seq_len,
1645 req.source_image.as_ref().map(|_| req.strength),
1646 );
1647
1648 if req.source_image.is_some() {
1649 tracing::info!(
1650 strength = req.strength,
1651 start_index,
1652 start_sigma = scheduler.sigmas[0],
1653 remaining_sigmas = scheduler.sigmas.len(),
1654 remaining_steps = scheduler.sigmas.len().saturating_sub(1),
1655 "img2img: truncated schedule from strength"
1656 );
1657 }
1658
1659 let (mut latents, inpaint_ctx) = if let Some(ref source_bytes) = req.source_image {
1661 let start_sigma = scheduler.sigmas[0];
1662
1663 progress.stage_start("Encoding source image (VAE)");
1665 let encode_start = Instant::now();
1666 let vae_encode_device = &loaded.vae_device;
1667 let vae_encode_dtype = if loaded.vae_device.is_cpu() {
1668 DType::F32
1669 } else {
1670 loaded.dtype
1671 };
1672 let source_tensor = img_utils::decode_source_image(
1673 source_bytes,
1674 req.width,
1675 req.height,
1676 img_utils::NormalizeRange::ZeroToOne,
1677 vae_encode_device,
1678 vae_encode_dtype,
1679 )?;
1680 let encoded = loaded.vae.encode(&source_tensor)?;
1681 progress.stage_done("Encoding source image (VAE)", encode_start.elapsed());
1682
1683 let encoded = encoded.to_dtype(loaded.dtype)?.to_device(&loaded.device)?;
1684
1685 let prepared = crate::img2img::prepare_flow_match_img2img(
1686 &encoded,
1687 seed,
1688 &[1, 16, latent_h, latent_w],
1689 start_sigma,
1690 req.mask_image.as_deref(),
1691 latent_h,
1692 latent_w,
1693 &loaded.device,
1694 loaded.dtype,
1695 )?;
1696 (prepared.initial_latents.unsqueeze(2)?, prepared.inpaint_ctx)
1697 } else {
1698 let noise = crate::engine::seeded_randn(
1700 seed,
1701 &[1, 16, latent_h, latent_w],
1702 &loaded.device,
1703 loaded.dtype,
1704 )?;
1705 (noise.unsqueeze(2)?, None)
1706 };
1707
1708 let num_steps = scheduler.sigmas.len().saturating_sub(1);
1710 let denoise_label = format!("Denoising ({} steps)", num_steps);
1711 progress.stage_start(&denoise_label);
1712 let denoise_start = Instant::now();
1713
1714 {
1716 let transformer = loaded
1717 .transformer
1718 .as_ref()
1719 .expect("transformer must be loaded for denoising");
1720
1721 for step in 0..num_steps {
1722 let step_start = Instant::now();
1723 let t = model_timestep(&scheduler);
1724 let t_tensor = Tensor::from_vec(vec![t as f32], (1,), &loaded.device)?
1725 .to_dtype(loaded.dtype)?;
1726 if zimage_debug_enabled() {
1727 tracing::debug!(
1728 step = step + 1,
1729 total = num_steps,
1730 sigma = scheduler.current_sigma(),
1731 timestep = t,
1732 "{}",
1733 tensor_stats_summary("latents_in", &latents)?
1734 );
1735 }
1736
1737 let noise_pred = transformer.forward(&latents, &t_tensor, &cap_feats, &cap_mask)?;
1739 if zimage_debug_enabled() {
1740 tracing::debug!(
1741 step = step + 1,
1742 total = num_steps,
1743 "{}",
1744 tensor_stats_summary("noise_pred_raw", &noise_pred)?
1745 );
1746 }
1747
1748 let noise_pred = noise_pred.neg()?;
1750
1751 let noise_pred_4d = noise_pred.squeeze(2)?;
1753 let latents_4d = latents.squeeze(2)?;
1754
1755 let prev_latents = scheduler.step(&noise_pred_4d, &latents_4d)?;
1757
1758 latents = prev_latents.unsqueeze(2)?;
1760 if zimage_debug_enabled() {
1761 tracing::debug!(
1762 step = step + 1,
1763 total = num_steps,
1764 sigma_next = scheduler.current_sigma(),
1765 "{}",
1766 tensor_stats_summary("latents_out", &latents)?
1767 );
1768 }
1769
1770 if let Some(ref ctx) = inpaint_ctx {
1772 let latents_4d = latents.squeeze(2)?;
1773 let blended = crate::img2img::apply_flow_match_inpaint(
1774 &latents_4d,
1775 ctx,
1776 scheduler.sigmas[step + 1],
1777 )?;
1778 latents = blended.unsqueeze(2)?;
1779 }
1780
1781 progress.emit(ProgressEvent::DenoiseStep {
1782 step: step + 1,
1783 total: num_steps,
1784 elapsed: step_start.elapsed(),
1785 });
1786 }
1787 }
1788
1789 progress.stage_done(&denoise_label, denoise_start.elapsed());
1790 tracing::info!("denoising complete");
1791
1792 drop(cap_feats);
1794 drop(cap_mask);
1795
1796 loaded.transformer = None;
1800 loaded.device.synchronize()?;
1803 tracing::info!("Z-Image transformer dropped from GPU to free VRAM for VAE decode");
1804
1805 progress.stage_start("VAE decode");
1807 let vae_start = Instant::now();
1808
1809 let latents_4d = latents.squeeze(2)?;
1811
1812 let image = {
1814 let decode_latents = latents_4d.to_device(&loaded.vae_device)?.to_dtype(
1815 if loaded.vae_device.is_cpu() {
1816 DType::F32
1817 } else {
1818 loaded.vae_dtype
1819 },
1820 )?;
1821 match loaded.vae.decode(&decode_latents) {
1822 Ok(img) => img,
1823 Err(e) if loaded.vae_device.is_cuda() => {
1824 let err_msg = format!("{e}");
1826 if err_msg.contains("OUT_OF_MEMORY") || err_msg.contains("out of memory") {
1827 tracing::warn!("VAE decode OOM on GPU, falling back to CPU");
1828 progress.info("VAE decode OOM on GPU — retrying on CPU");
1829 loaded.device.synchronize()?;
1830 let cpu_vae = load_zimage_vae(
1832 loaded.vae_path.as_path(),
1833 DType::F32,
1834 &Device::Cpu,
1835 progress,
1836 None,
1837 )?;
1838 let cpu_latents =
1839 latents_4d.to_device(&Device::Cpu)?.to_dtype(DType::F32)?;
1840 cpu_vae.decode(&cpu_latents)?
1841 } else {
1842 return Err(e.into());
1843 }
1844 }
1845 Err(e) => return Err(e.into()),
1846 }
1847 };
1848
1849 let image = postprocess_image(&image)?;
1851 let image = image.i(0)?; progress.stage_done("VAE decode", vae_start.elapsed());
1854
1855 let output_metadata = build_output_metadata(req, seed, None);
1857 let image_bytes = encode_image(
1858 &image,
1859 req.resolved_output_format(),
1860 req.width,
1861 req.height,
1862 output_metadata.as_ref(),
1863 )?;
1864
1865 let generation_time_ms = start.elapsed().as_millis() as u64;
1866 tracing::info!(generation_time_ms, seed, "Z-Image generation complete");
1867
1868 Ok(GenerateResponse {
1869 images: vec![ImageData {
1870 data: image_bytes,
1871 format: req.resolved_output_format(),
1872 width: req.width,
1873 height: req.height,
1874 index: 0,
1875 }],
1876 generation_time_ms,
1877 model: req.model.clone(),
1878 seed_used: seed,
1879 video: None,
1880 gpu: None,
1881 })
1882 }
1883}
1884
1885impl InferenceEngine for ZImageEngine {
1886 fn generate(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1887 self.pending_placement = req.placement.clone();
1888 self.pending_loras = effective_zimage_loras(req);
1889 let result = self.generate_inner(req);
1890 self.pending_placement = None;
1891 self.pending_loras.clear();
1892 result
1893 }
1894
1895 fn model_name(&self) -> &str {
1896 self.base.model_name()
1897 }
1898
1899 fn is_loaded(&self) -> bool {
1900 self.base.is_loaded()
1902 }
1903
1904 fn load(&mut self) -> Result<()> {
1905 ZImageEngine::load(self)
1906 }
1907
1908 fn unload(&mut self) {
1909 self.base.unload();
1910 clear_cache(&self.prompt_cache);
1911 }
1912
1913 fn set_on_progress(&mut self, callback: ProgressCallback) {
1914 self.base.set_on_progress(callback);
1915 }
1916
1917 fn clear_on_progress(&mut self) {
1918 self.base.clear_on_progress();
1919 }
1920
1921 fn model_paths(&self) -> Option<&mold_core::ModelPaths> {
1922 Some(&self.base.paths)
1923 }
1924}
1925
1926#[cfg(test)]
1927mod tests {
1928 use super::*;
1929 use crate::device::should_use_gpu;
1930 use crate::engine::LoadStrategy;
1931 use crate::shared_pool::SharedPool;
1932 use mold_core::ModelPaths;
1933 use std::fs;
1934 use std::path::{Path, PathBuf};
1935 use std::sync::{Arc, Mutex};
1936 use std::time::{SystemTime, UNIX_EPOCH};
1937 use tokenizers::models::bpe::BPE;
1938
1939 fn temp_test_dir(prefix: &str) -> PathBuf {
1940 let suffix = SystemTime::now()
1941 .duration_since(UNIX_EPOCH)
1942 .unwrap()
1943 .as_nanos();
1944 let dir = std::env::temp_dir().join(format!("{prefix}-{}-{suffix}", std::process::id()));
1945 fs::create_dir_all(&dir).unwrap();
1946 dir
1947 }
1948
1949 fn touch(dir: &Path, name: &str) -> PathBuf {
1950 let path = dir.join(name);
1951 fs::write(&path, b"test").unwrap();
1952 path
1953 }
1954
1955 fn zimage_model_paths(
1956 transformer: PathBuf,
1957 transformer_shards: Vec<PathBuf>,
1958 vae: PathBuf,
1959 text_tokenizer: Option<PathBuf>,
1960 ) -> ModelPaths {
1961 ModelPaths {
1962 transformer,
1963 transformer_shards,
1964 vae,
1965 spatial_upscaler: None,
1966 temporal_upscaler: None,
1967 distilled_lora: None,
1968 t5_encoder: None,
1969 clip_encoder: None,
1970 t5_tokenizer: None,
1971 clip_tokenizer: None,
1972 clip_encoder_2: None,
1973 clip_tokenizer_2: None,
1974 text_encoder_files: vec![],
1975 text_tokenizer,
1976 decoder: None,
1977 }
1978 }
1979
1980 #[test]
1981 fn zimage_safetensors_backend_accepts_civitai_diffusion_prefix() {
1982 use candle_nn::var_builder::SimpleBackend;
1983 use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
1984 use std::collections::HashMap;
1985
1986 fn f32_bytes(values: &[f32]) -> Vec<u8> {
1987 values
1988 .iter()
1989 .flat_map(|value| value.to_le_bytes())
1990 .collect()
1991 }
1992
1993 let dir = temp_test_dir("mold-zimage-prefix-backend");
1994 let path = dir.join("zimage.safetensors");
1995 let data = f32_bytes(&[42.0]);
1996 let qkv = f32_bytes(&[1.0, 2.0, 3.0]);
1997 let out = f32_bytes(&[7.0]);
1998 let q_norm = f32_bytes(&[8.0]);
1999 let k_norm = f32_bytes(&[9.0]);
2000 let mut tensors = HashMap::new();
2001 tensors.insert(
2002 format!("{ZIMAGE_SINGLE_FILE_PREFIX}t_embedder.mlp.0.weight"),
2003 TensorView::new(SafeDtype::F32, vec![1, 1], data.as_slice()).unwrap(),
2004 );
2005 tensors.insert(
2006 format!("{ZIMAGE_SINGLE_FILE_PREFIX}x_embedder.weight"),
2007 TensorView::new(SafeDtype::F32, vec![1, 1], data.as_slice()).unwrap(),
2008 );
2009 tensors.insert(
2010 format!("{ZIMAGE_SINGLE_FILE_PREFIX}noise_refiner.0.attention.qkv.weight"),
2011 TensorView::new(SafeDtype::F32, vec![3, 1], qkv.as_slice()).unwrap(),
2012 );
2013 tensors.insert(
2014 format!("{ZIMAGE_SINGLE_FILE_PREFIX}noise_refiner.0.attention.out.weight"),
2015 TensorView::new(SafeDtype::F32, vec![1, 1], out.as_slice()).unwrap(),
2016 );
2017 tensors.insert(
2018 format!("{ZIMAGE_SINGLE_FILE_PREFIX}noise_refiner.0.attention.q_norm.weight"),
2019 TensorView::new(SafeDtype::F32, vec![1], q_norm.as_slice()).unwrap(),
2020 );
2021 tensors.insert(
2022 format!("{ZIMAGE_SINGLE_FILE_PREFIX}noise_refiner.0.attention.k_norm.weight"),
2023 TensorView::new(SafeDtype::F32, vec![1], k_norm.as_slice()).unwrap(),
2024 );
2025 serialize_to_file(&tensors, &None, &path).unwrap();
2026
2027 let st = unsafe { candle_core::safetensors::MmapedSafetensors::multi(&[path.as_path()]) }
2028 .unwrap();
2029 let backend = ZImageSafetensorsBackend::new(st);
2030 assert!(backend.contains_tensor("t_embedder.mlp.0.weight"));
2031 let tensor = backend
2032 .get_unchecked("t_embedder.mlp.0.weight", DType::F32, &Device::Cpu)
2033 .unwrap();
2034 assert_eq!(tensor.to_vec2::<f32>().unwrap(), vec![vec![42.0]]);
2035 assert!(backend.contains_tensor("all_x_embedder.2-1.weight"));
2036 let alias_tensor = backend
2037 .get_unchecked("all_x_embedder.2-1.weight", DType::F32, &Device::Cpu)
2038 .unwrap();
2039 assert_eq!(alias_tensor.to_vec2::<f32>().unwrap(), vec![vec![42.0]]);
2040 assert!(backend.contains_tensor("noise_refiner.0.attention.to_q.weight"));
2041 assert!(backend.contains_tensor("noise_refiner.0.attention.to_k.weight"));
2042 assert!(backend.contains_tensor("noise_refiner.0.attention.to_v.weight"));
2043 let k = backend
2044 .get(
2045 Shape::from((1, 1)),
2046 "noise_refiner.0.attention.to_k.weight",
2047 candle_nn::Init::Const(0.0),
2048 DType::F32,
2049 &Device::Cpu,
2050 )
2051 .unwrap();
2052 assert_eq!(k.to_vec2::<f32>().unwrap(), vec![vec![2.0]]);
2053 let out = backend
2054 .get_unchecked(
2055 "noise_refiner.0.attention.to_out.0.weight",
2056 DType::F32,
2057 &Device::Cpu,
2058 )
2059 .unwrap();
2060 assert_eq!(out.to_vec2::<f32>().unwrap(), vec![vec![7.0]]);
2061 let q_norm = backend
2062 .get_unchecked(
2063 "noise_refiner.0.attention.norm_q.weight",
2064 DType::F32,
2065 &Device::Cpu,
2066 )
2067 .unwrap();
2068 assert_eq!(q_norm.to_vec1::<f32>().unwrap(), vec![8.0]);
2069 let k_norm = backend
2070 .get_unchecked(
2071 "noise_refiner.0.attention.norm_k.weight",
2072 DType::F32,
2073 &Device::Cpu,
2074 )
2075 .unwrap();
2076 assert_eq!(k_norm.to_vec1::<f32>().unwrap(), vec![9.0]);
2077
2078 let _ = std::fs::remove_dir_all(dir);
2079 }
2080
2081 #[test]
2082 fn zimage_vae_backend_accepts_bare_ldm_vae_keys() {
2083 use candle_nn::var_builder::SimpleBackend;
2084 use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
2085 use std::collections::HashMap;
2086
2087 fn f32_bytes(values: &[f32]) -> Vec<u8> {
2088 values
2089 .iter()
2090 .flat_map(|value| value.to_le_bytes())
2091 .collect()
2092 }
2093
2094 let dir = temp_test_dir("mold-zimage-vae-backend");
2095 let path = dir.join("vae.safetensors");
2096 let norm = f32_bytes(&[5.0]);
2097 let conv = f32_bytes(&[7.0]);
2098 let attn_q = f32_bytes(&[1.0, 2.0, 3.0, 4.0]);
2099 let mut tensors = HashMap::new();
2100 tensors.insert(
2101 "encoder.down.0.block.0.norm1.weight".to_string(),
2102 TensorView::new(SafeDtype::F32, vec![1], norm.as_slice()).unwrap(),
2103 );
2104 tensors.insert(
2105 "decoder.norm_out.weight".to_string(),
2106 TensorView::new(SafeDtype::F32, vec![1], conv.as_slice()).unwrap(),
2107 );
2108 tensors.insert(
2109 "encoder.mid.attn_1.q.weight".to_string(),
2110 TensorView::new(SafeDtype::F32, vec![2, 2, 1, 1], attn_q.as_slice()).unwrap(),
2111 );
2112 serialize_to_file(&tensors, &None, &path).unwrap();
2113
2114 let st = unsafe { candle_core::safetensors::MmapedSafetensors::multi(&[path.as_path()]) }
2115 .unwrap();
2116 let backend = ZImageVaeSafetensorsBackend::new(st);
2117
2118 assert!(backend.contains_tensor("encoder.down_blocks.0.resnets.0.norm1.weight"));
2119 let norm = backend
2120 .get_unchecked(
2121 "encoder.down_blocks.0.resnets.0.norm1.weight",
2122 DType::F32,
2123 &Device::Cpu,
2124 )
2125 .unwrap();
2126 assert_eq!(norm.to_vec1::<f32>().unwrap(), vec![5.0]);
2127
2128 assert!(backend.contains_tensor("decoder.conv_norm_out.weight"));
2129 let conv = backend
2130 .get_unchecked("decoder.conv_norm_out.weight", DType::F32, &Device::Cpu)
2131 .unwrap();
2132 assert_eq!(conv.to_vec1::<f32>().unwrap(), vec![7.0]);
2133 let q = backend
2134 .get(
2135 Shape::from((2, 2)),
2136 "encoder.mid_block.attentions.0.to_q.weight",
2137 candle_nn::Init::Const(0.0),
2138 DType::F32,
2139 &Device::Cpu,
2140 )
2141 .unwrap();
2142 assert_eq!(
2143 q.to_vec2::<f32>().unwrap(),
2144 vec![vec![1.0, 2.0], vec![3.0, 4.0]]
2145 );
2146
2147 let _ = std::fs::remove_dir_all(dir);
2148 }
2149
2150 #[test]
2151 fn zimage_vae_cpu_tensor_backend_preserves_aliases_and_reshape() {
2152 use candle_nn::var_builder::SimpleBackend;
2153 use std::collections::HashMap;
2154
2155 let device = Device::Cpu;
2156 let mut tensors = HashMap::new();
2157 tensors.insert(
2158 "encoder.down.0.block.0.norm1.weight".to_string(),
2159 Tensor::new(&[5.0f32], &device).unwrap(),
2160 );
2161 tensors.insert(
2162 "encoder.mid.attn_1.q.weight".to_string(),
2163 Tensor::new(&[1.0f32, 2.0, 3.0, 4.0], &device)
2164 .unwrap()
2165 .reshape((2, 2, 1, 1))
2166 .unwrap(),
2167 );
2168
2169 let backend = ZImageVaeSafetensorsBackend::from_cpu_tensors(Arc::new(tensors));
2170
2171 assert!(backend.contains_tensor("encoder.down_blocks.0.resnets.0.norm1.weight"));
2172 let norm = backend
2173 .get_unchecked(
2174 "encoder.down_blocks.0.resnets.0.norm1.weight",
2175 DType::F32,
2176 &device,
2177 )
2178 .unwrap();
2179 assert_eq!(norm.to_vec1::<f32>().unwrap(), vec![5.0]);
2180
2181 let q = backend
2182 .get(
2183 Shape::from((2, 2)),
2184 "encoder.mid_block.attentions.0.to_q.weight",
2185 candle_nn::Init::Const(0.0),
2186 DType::F32,
2187 &device,
2188 )
2189 .unwrap();
2190 assert_eq!(
2191 q.to_vec2::<f32>().unwrap(),
2192 vec![vec![1.0, 2.0], vec![3.0, 4.0]]
2193 );
2194 }
2195
2196 #[test]
2197 fn zimage_loads_vae_tensors_through_shared_pool() {
2198 use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
2199 use std::collections::HashMap;
2200
2201 let dir = temp_test_dir("mold-zimage-vae-pool");
2202 let vae_path = dir.join("vae.safetensors");
2203 let weight = 1.0f32.to_le_bytes();
2204 let mut tensors = HashMap::new();
2205 tensors.insert(
2206 "encoder.conv_in.weight".to_string(),
2207 TensorView::new(SafeDtype::F32, vec![1], weight.as_slice()).unwrap(),
2208 );
2209 serialize_to_file(&tensors, &None, &vae_path).unwrap();
2210
2211 let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
2212 let pooled = shared_pool
2213 .lock()
2214 .unwrap()
2215 .load_safetensors_cpu_tensors(std::slice::from_ref(&vae_path))
2216 .unwrap()
2217 .unwrap();
2218
2219 let engine = ZImageEngine::new(
2220 "z-image-turbo:q4".to_string(),
2221 zimage_model_paths(
2222 dir.join("transformer.gguf"),
2223 vec![],
2224 vae_path,
2225 Some(dir.join("tokenizer.json")),
2226 ),
2227 None,
2228 LoadStrategy::Sequential,
2229 0,
2230 false,
2231 Some(shared_pool),
2232 );
2233
2234 let loaded = engine.load_vae_cpu_tensors().unwrap().unwrap();
2235
2236 assert!(Arc::ptr_eq(&pooled, &loaded));
2237 fs::remove_dir_all(dir).ok();
2238 }
2239
2240 #[test]
2241 fn zimage_recipe_text_encoder_defaults_to_bf16_variant() {
2242 let recipe_paths = vec![std::path::PathBuf::from(
2243 "/models/cv-2442439/z-image/civitai/2442439/zImageTurbo_turbo_txt.safetensors",
2244 )];
2245 let shared_paths = vec![std::path::PathBuf::from(
2246 "/models/shared/z-image/text_encoder/model-00001-of-00003.safetensors",
2247 )];
2248
2249 assert_eq!(zimage_qwen3_preference(None, &recipe_paths), Some("bf16"));
2250 assert_eq!(zimage_qwen3_preference(None, &shared_paths), None);
2251 assert_eq!(
2252 zimage_qwen3_preference(Some("q8"), &recipe_paths),
2253 Some("q8")
2254 );
2255 assert_eq!(
2256 zimage_qwen3_preference(Some("auto"), &recipe_paths),
2257 Some("auto")
2258 );
2259 }
2260
2261 #[test]
2262 fn latent_dimensions() {
2263 assert_eq!(2 * (1024 / 16), 128);
2265 assert_eq!(2 * (512 / 16), 64);
2267 assert_eq!(2 * (768 / 16), 96);
2269 }
2270
2271 #[test]
2274 fn qwen3_on_gpu_on_24gb_with_q8_drop_reload() {
2275 assert!(should_use_gpu(
2278 true,
2279 false,
2280 17_000_000_000,
2281 QWEN3_FP16_VRAM_THRESHOLD
2282 ));
2283 }
2284
2285 #[test]
2286 fn qwen3_on_gpu_on_24gb_with_q4_drop_reload() {
2287 assert!(should_use_gpu(
2290 true,
2291 false,
2292 19_000_000_000,
2293 QWEN3_FP16_VRAM_THRESHOLD
2294 ));
2295 }
2296
2297 #[test]
2298 fn qwen3_on_cpu_with_bf16_transformer() {
2299 assert!(!should_use_gpu(
2302 true,
2303 false,
2304 400_000_000,
2305 QWEN3_FP16_VRAM_THRESHOLD
2306 ));
2307 }
2308
2309 #[test]
2310 fn qwen3_on_gpu_on_48gb_card() {
2311 assert!(should_use_gpu(
2312 true,
2313 false,
2314 40_000_000_000,
2315 QWEN3_FP16_VRAM_THRESHOLD
2316 ));
2317 }
2318
2319 #[test]
2320 fn qwen3_on_gpu_on_metal() {
2321 assert!(should_use_gpu(false, true, 0, QWEN3_FP16_VRAM_THRESHOLD));
2323 }
2324
2325 #[test]
2326 fn vae_on_gpu_when_plenty_of_vram() {
2327 assert!(should_use_gpu(
2328 true,
2329 false,
2330 17_000_000_000,
2331 VAE_DECODE_VRAM_THRESHOLD
2332 ));
2333 }
2334
2335 #[test]
2336 fn eager_vae_weight_load_threshold_is_below_decode_workspace_threshold() {
2337 const {
2338 assert!(VAE_WEIGHT_LOAD_VRAM_THRESHOLD < VAE_DECODE_VRAM_THRESHOLD);
2339 }
2340 assert!(should_use_gpu(
2341 true,
2342 false,
2343 1_000_000_000,
2344 VAE_WEIGHT_LOAD_VRAM_THRESHOLD
2345 ));
2346 }
2347
2348 #[test]
2349 fn vae_on_cpu_when_vram_tight() {
2350 assert!(!should_use_gpu(
2351 true,
2352 false,
2353 5_400_000_000,
2354 VAE_DECODE_VRAM_THRESHOLD
2355 ));
2356 }
2357
2358 #[test]
2359 fn vae_on_gpu_on_metal() {
2360 assert!(should_use_gpu(false, true, 0, VAE_DECODE_VRAM_THRESHOLD));
2362 }
2363
2364 #[test]
2367 fn qwen3_threshold_allows_gpu_on_24gb_with_quantized_xformer() {
2368 let threshold = std::hint::black_box(QWEN3_FP16_VRAM_THRESHOLD);
2371 assert!(threshold < 17_000_000_000);
2372 }
2373
2374 #[test]
2375 fn qwen3_threshold_exceeds_encoder_size() {
2376 let threshold = std::hint::black_box(QWEN3_FP16_VRAM_THRESHOLD);
2377 assert!(threshold > 8_200_000_000);
2378 }
2379
2380 #[test]
2381 fn vae_threshold_accounts_for_decode_workspace() {
2382 let threshold = std::hint::black_box(VAE_DECODE_VRAM_THRESHOLD);
2383 assert!(threshold > 160_000_000);
2384 assert!(threshold < 15_000_000_000);
2385 }
2386
2387 #[test]
2388 fn zimage_scheduler_uses_shifted_reference_sigmas() {
2389 let image_seq_len = 1024;
2390 let (full, _) = build_zimage_scheduler(9, image_seq_len, None);
2391 let (scheduler, start_index) = build_zimage_scheduler(9, image_seq_len, Some(0.5));
2392 let expected_sigmas = full.sigmas[start_index..].to_vec();
2393 let expected_timesteps = expected_sigmas[..expected_sigmas.len() - 1]
2394 .iter()
2395 .map(|sigma| sigma * 1000.0)
2396 .collect::<Vec<_>>();
2397
2398 assert_eq!(start_index, crate::img2img::img2img_start_index(9, 0.5));
2399 assert_eq!(scheduler.sigmas, expected_sigmas);
2400 assert_eq!(scheduler.timesteps, expected_timesteps);
2401 assert_eq!(scheduler.sigmas.last().copied(), Some(0.0));
2402 }
2403
2404 #[test]
2405 fn zimage_model_timestep_matches_scheduler_timesteps() {
2406 let (scheduler, _) = build_zimage_scheduler(9, 1024, Some(0.5));
2407 let t = model_timestep(&scheduler);
2408 assert!(
2409 (t - (1.0 - scheduler.sigmas[0])).abs() < 1e-10,
2410 "expected model timestep to match 1-sigma semantics, got {t} vs {}",
2411 1.0 - scheduler.sigmas[0]
2412 );
2413 }
2414
2415 #[test]
2416 fn zimage_img2img_source_decode_uses_vae_native_zero_to_one_range() {
2417 let source = include_str!("pipeline.rs")
2418 .split("#[cfg(test)]\nmod tests")
2419 .next()
2420 .expect("pipeline source should include production section");
2421 let decode_sites = source
2422 .split("let source_tensor = img_utils::decode_source_image(")
2423 .skip(1)
2424 .collect::<Vec<_>>();
2425
2426 assert_eq!(decode_sites.len(), 2);
2427 for site in decode_sites {
2428 let args = site
2429 .split(")?;")
2430 .next()
2431 .expect("source decode call should terminate");
2432 assert!(
2433 args.contains("img_utils::NormalizeRange::ZeroToOne"),
2434 "Z-Image source-image encoding must use the VAE-native [0, 1] range"
2435 );
2436 assert!(
2437 !args.contains("img_utils::NormalizeRange::MinusOneToOne"),
2438 "Z-Image source-image encoding must not use [-1, 1] normalization"
2439 );
2440 }
2441 }
2442
2443 #[test]
2444 fn zimage_zero_strength_preserves_terminal_zero_only() {
2445 let (scheduler, start_index) = build_zimage_scheduler(9, 1024, Some(0.0));
2446
2447 assert_eq!(start_index, 9);
2448 assert_eq!(scheduler.sigmas, vec![0.0]);
2449 assert!(scheduler.timesteps.is_empty());
2450 }
2451
2452 #[test]
2453 fn tensor_stats_summary_reports_expected_values() {
2454 let tensor =
2455 Tensor::from_vec(vec![1.0f32, -1.0, 3.0, -3.0], (1, 1, 2, 2), &Device::Cpu).unwrap();
2456 let summary = tensor_stats_summary("probe", &tensor).unwrap();
2457
2458 assert!(summary.contains("probe:"));
2459 assert!(summary.contains("mean=0.00000"));
2460 assert!(summary.contains("min=-3.00000"));
2461 assert!(summary.contains("max=3.00000"));
2462 assert!(summary.contains("rms=2.23607"));
2463 }
2464
2465 #[test]
2466 fn zimage_transformer_paths_prefer_shards_when_present() {
2467 let dir = temp_test_dir("mold-zimage-shards");
2468 let shard_a = touch(&dir, "transformer-00001-of-00002.safetensors");
2469 let shard_b = touch(&dir, "transformer-00002-of-00002.safetensors");
2470 let engine = ZImageEngine::new(
2471 "z-image-turbo:bf16".to_string(),
2472 zimage_model_paths(
2473 dir.join("transformer.safetensors"),
2474 vec![shard_a.clone(), shard_b.clone()],
2475 dir.join("vae.safetensors"),
2476 Some(dir.join("tokenizer.json")),
2477 ),
2478 None,
2479 LoadStrategy::Sequential,
2480 0,
2481 false,
2482 None,
2483 );
2484
2485 assert_eq!(engine.transformer_paths(), vec![shard_a, shard_b]);
2486
2487 fs::remove_dir_all(dir).ok();
2488 }
2489
2490 #[test]
2491 fn zimage_validate_paths_accepts_existing_files() {
2492 let dir = temp_test_dir("mold-zimage-validate-ok");
2493 let shard_a = touch(&dir, "transformer-00001-of-00002.safetensors");
2494 let shard_b = touch(&dir, "transformer-00002-of-00002.safetensors");
2495 let vae = touch(&dir, "vae.safetensors");
2496 let tokenizer = touch(&dir, "tokenizer.json");
2497 let gguf = touch(&dir, "transformer.gguf");
2498
2499 let sharded = ZImageEngine::new(
2500 "z-image-turbo:bf16".to_string(),
2501 zimage_model_paths(
2502 dir.join("transformer.safetensors"),
2503 vec![shard_a, shard_b],
2504 vae.clone(),
2505 Some(tokenizer.clone()),
2506 ),
2507 None,
2508 LoadStrategy::Sequential,
2509 0,
2510 false,
2511 None,
2512 );
2513 assert_eq!(sharded.validate_paths().unwrap(), tokenizer);
2514 assert!(!sharded.detect_is_gguf());
2515
2516 let quantized = ZImageEngine::new(
2517 "z-image-turbo:q4".to_string(),
2518 zimage_model_paths(gguf, vec![], vae, Some(dir.join("tokenizer.json"))),
2519 None,
2520 LoadStrategy::Sequential,
2521 0,
2522 false,
2523 None,
2524 );
2525 assert!(quantized.detect_is_gguf());
2526
2527 fs::remove_dir_all(dir).ok();
2528 }
2529
2530 #[test]
2531 fn zimage_lora_requests_use_sequential_generation_path() {
2532 let dir = temp_test_dir("mold-zimage-lora-sequential");
2533 let mut engine = ZImageEngine::new(
2534 "z-image-turbo:q8".to_string(),
2535 zimage_model_paths(
2536 dir.join("transformer.gguf"),
2537 vec![],
2538 dir.join("vae.safetensors"),
2539 Some(dir.join("tokenizer.json")),
2540 ),
2541 None,
2542 LoadStrategy::Eager,
2543 0,
2544 false,
2545 None,
2546 );
2547 engine.pending_loras = vec![LoraWeight {
2548 path: dir.join("adapter.safetensors").display().to_string(),
2549 scale: 1.0,
2550 }];
2551
2552 assert!(
2553 engine.uses_sequential_generate_path(),
2554 "Z-Image LoRA requests should use staged load-use-drop generation \
2555 so VAE/text encoders are not co-resident with the LoRA-merged transformer"
2556 );
2557
2558 fs::remove_dir_all(dir).ok();
2559 }
2560
2561 #[test]
2562 fn zimage_sequential_path_drops_eager_components_before_generation() {
2563 let source = include_str!("pipeline.rs");
2564 let sequential_branch = source
2565 .split("// Eager mode: use pre-loaded components")
2566 .next()
2567 .expect("generate_inner should contain eager-mode marker");
2568
2569 assert!(
2570 sequential_branch.contains("self.base.unload();")
2571 && sequential_branch.contains("return self.generate_sequential(req);"),
2572 "Z-Image LoRA/offload sequential generation must drop eager-loaded \
2573 components before loading staged components"
2574 );
2575 }
2576
2577 #[test]
2578 fn zimage_eager_path_reloads_after_sequential_generation_unloads_components() {
2579 let source = include_str!("pipeline.rs");
2580 let eager_branch = source
2581 .split("// Eager mode: use pre-loaded components")
2582 .nth(1)
2583 .expect("generate_inner should contain eager-mode branch");
2584 let reload_idx = eager_branch
2585 .find("self.load()?;")
2586 .expect("eager branch should reload an unloaded cached engine");
2587 let guard_idx = eager_branch
2588 .find("bail!(\"model not loaded")
2589 .expect("eager branch should retain a final loaded-state guard");
2590
2591 assert!(
2592 reload_idx < guard_idx,
2593 "Z-Image eager generation must reload after a prior LoRA/offload \
2594 sequential request unloads cached components"
2595 );
2596 }
2597
2598 #[test]
2599 fn zimage_forced_offload_uses_sequential_generation_path() {
2600 let dir = temp_test_dir("mold-zimage-offload-sequential");
2601 let engine = ZImageEngine::new(
2602 "z-image-turbo:bf16".to_string(),
2603 zimage_model_paths(
2604 dir.join("transformer.safetensors"),
2605 vec![],
2606 dir.join("vae.safetensors"),
2607 Some(dir.join("tokenizer.json")),
2608 ),
2609 None,
2610 LoadStrategy::Eager,
2611 0,
2612 true,
2613 None,
2614 );
2615
2616 assert!(
2617 engine.uses_sequential_generate_path(),
2618 "Z-Image --offload requests must reach the engine and select the \
2619 staged generation path instead of being silently ignored"
2620 );
2621
2622 fs::remove_dir_all(dir).ok();
2623 }
2624
2625 #[test]
2626 fn zimage_offload_decision_gates_current_unsupported_cases() {
2627 assert_eq!(
2628 zimage_offload_decision(false, false, false),
2629 ZImageOffloadDecision::Disabled
2630 );
2631 assert_eq!(
2632 zimage_offload_decision(true, false, false),
2633 ZImageOffloadDecision::Selected
2634 );
2635 assert!(matches!(
2636 zimage_offload_decision(true, true, false),
2637 ZImageOffloadDecision::Unsupported(reason)
2638 if reason.contains("GGUF variants")
2639 ));
2640 assert!(matches!(
2641 zimage_offload_decision(true, false, true),
2642 ZImageOffloadDecision::Unsupported(reason)
2643 if reason.contains("LoRA")
2644 ));
2645 }
2646
2647 #[test]
2648 fn zimage_selected_bf16_offload_reaches_runtime_loader() {
2649 let dir = temp_test_dir("mold-zimage-offload-loader");
2650 let mut engine = ZImageEngine::new(
2651 "z-image-turbo:bf16".to_string(),
2652 zimage_model_paths(
2653 touch(&dir, "transformer.safetensors"),
2654 vec![],
2655 touch(&dir, "vae.safetensors"),
2656 Some(touch(&dir, "tokenizer.json")),
2657 ),
2658 None,
2659 LoadStrategy::Sequential,
2660 0,
2661 true,
2662 None,
2663 );
2664 let req = GenerateRequest {
2665 prompt: "a cat".to_string(),
2666 negative_prompt: None,
2667 model: "z-image-turbo:bf16".to_string(),
2668 width: 64,
2669 height: 64,
2670 steps: 1,
2671 guidance: 0.0,
2672 seed: Some(1),
2673 batch_size: 1,
2674 output_format: None,
2675 embed_metadata: None,
2676 scheduler: None,
2677 cfg_plus: None,
2678 source_image: None,
2679 edit_images: None,
2680 strength: 1.0,
2681 mask_image: None,
2682 control_image: None,
2683 control_model: None,
2684 control_scale: 1.0,
2685 expand: None,
2686 original_prompt: None,
2687 lora: None,
2688 frames: None,
2689 fps: None,
2690 upscale_model: None,
2691 gif_preview: false,
2692 enable_audio: None,
2693 audio_file: None,
2694 audio_file_path: None,
2695 source_video: None,
2696 source_video_path: None,
2697 keyframes: None,
2698 pipeline: None,
2699 loras: None,
2700 retake_range: None,
2701 spatial_upscale: None,
2702 temporal_upscale: None,
2703 placement: None,
2704 };
2705
2706 let err = engine.generate_sequential(&req).unwrap_err().to_string();
2707
2708 assert!(
2709 !err.contains("streaming is not implemented yet"),
2710 "selected BF16 offload must reach the runtime loader, got: {err}"
2711 );
2712 fs::remove_dir_all(dir).ok();
2713 }
2714
2715 #[test]
2716 fn zimage_validate_paths_requires_text_tokenizer() {
2717 let dir = temp_test_dir("mold-zimage-validate-missing");
2718 let engine = ZImageEngine::new(
2719 "z-image-turbo:q4".to_string(),
2720 zimage_model_paths(
2721 dir.join("transformer.gguf"),
2722 vec![],
2723 dir.join("vae.safetensors"),
2724 None,
2725 ),
2726 None,
2727 LoadStrategy::Sequential,
2728 0,
2729 false,
2730 None,
2731 );
2732
2733 let err = engine.validate_paths().unwrap_err();
2734 assert!(err.to_string().contains("text tokenizer path required"));
2735
2736 fs::remove_dir_all(dir).ok();
2737 }
2738
2739 #[test]
2740 fn zimage_loads_qwen3_tokenizer_through_shared_pool() {
2741 let dir = temp_test_dir("mold-zimage-tokenizer-pool");
2742 let tokenizer_path = dir.join("tokenizer.json");
2743 tokenizers::Tokenizer::new(BPE::default())
2744 .save(&tokenizer_path, false)
2745 .unwrap();
2746
2747 let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
2748 let pooled = shared_pool
2749 .lock()
2750 .unwrap()
2751 .load_tokenizer(&tokenizer_path)
2752 .unwrap();
2753
2754 let engine = ZImageEngine::new(
2755 "z-image-turbo:q4".to_string(),
2756 zimage_model_paths(
2757 dir.join("transformer.gguf"),
2758 vec![],
2759 dir.join("vae.safetensors"),
2760 Some(tokenizer_path.clone()),
2761 ),
2762 None,
2763 LoadStrategy::Sequential,
2764 0,
2765 false,
2766 Some(shared_pool),
2767 );
2768
2769 let loaded = engine.load_text_tokenizer(&tokenizer_path).unwrap();
2770
2771 assert!(Arc::ptr_eq(&pooled, &loaded));
2772 fs::remove_dir_all(dir).ok();
2773 }
2774}