Skip to main content

mold_inference/flux2/
pipeline.rs

1//! Flux.2 Klein inference engine (4B and 9B variants).
2//!
3//! Follows the same Eager + Sequential loading pattern as FluxEngine and ZImageEngine.
4//!
5//! Key differences from FLUX.1:
6//! - Uses Qwen3 text encoder (not T5 + CLIP)
7//!   - Klein-4B: Qwen3-4B (hidden=2560), stacked layers → 7680-dim context
8//!   - Klein-9B: Qwen3-8B (hidden=4096), stacked layers → 12288-dim context
9//! - VAE has latent_channels=32 (not 16)
10//! - Transformer has 128 input channels (not 64)
11//! - 4D RoPE (not 3D)
12//! - Klein is distilled (no guidance embedding)
13//! - No pooled text vector input
14//! - Linear timestep schedule (distilled, no time-shifting)
15
16use anyhow::{bail, Result};
17use candle_core::{DType, Device, IndexOp, Tensor};
18use candle_nn::VarBuilder;
19use mold_core::{GenerateRequest, GenerateResponse, ImageData, LoraWeight, ModelPaths};
20use std::collections::HashMap;
21use std::path::{Path, PathBuf};
22use std::sync::{Arc, Mutex};
23use std::time::Instant;
24use tokenizers::Tokenizer;
25
26use super::sampling::{self, Flux2State};
27use super::transformer::{Flux2Config, Flux2TransformerWrapper};
28use super::vae::{Flux2AutoEncoder, Flux2VaeConfig};
29
30use crate::cache::{
31    clear_cache, get_or_insert_cached_tensor, prompt_text_key, restore_cached_tensor, CachedTensor,
32    LruCache, DEFAULT_PROMPT_CACHE_CAPACITY,
33};
34use crate::device::{
35    check_memory_budget, effective_device_ref, fmt_gb, free_vram_bytes, memory_status_string,
36    preflight_memory_check, usable_free_vram_bytes,
37};
38use crate::encoders;
39use crate::engine::{rand_seed, InferenceEngine, LoadStrategy};
40use crate::engine_base::EngineBase;
41use crate::image::{build_output_metadata, encode_image};
42use crate::progress::{ProgressCallback, ProgressReporter};
43
44// ---------------------------------------------------------------------------
45// Loaded state
46// ---------------------------------------------------------------------------
47
48/// Loaded Flux.2 model components, ready for inference.
49struct LoadedFlux2 {
50    /// None after being dropped for VAE decode VRAM; reloaded on next generate.
51    transformer: Option<Flux2TransformerWrapper>,
52    text_encoder: encoders::qwen3::Qwen3Encoder,
53    vae: Flux2AutoEncoder,
54    /// GPU device for transformer + VAE
55    device: Device,
56    dtype: DType,
57    /// Effective VAE dtype after `MOLD_VAE_DTYPE` resolution. May differ from
58    /// `dtype` when fp32 VAE decode is forced to suppress banding artifacts.
59    /// Captured at load time; sequential reloads re-resolve per request.
60    vae_dtype: DType,
61}
62
63// ---------------------------------------------------------------------------
64// Engine
65// ---------------------------------------------------------------------------
66
67/// Flux.2 Klein inference engine (4B and 9B variants) backed by candle.
68pub struct Flux2Engine {
69    base: EngineBase<LoadedFlux2>,
70    /// Qwen3 variant preference: None/"auto" = VRAM-based, "bf16" = force BF16, "q8"/etc = specific.
71    qwen3_variant: Option<String>,
72    /// Force block-level transformer offload once the Flux.2 runtime supports
73    /// streaming BF16 blocks. Plumbed now so the request is explicit instead
74    /// of being silently treated as a regular dense load.
75    offload: bool,
76    prompt_cache: Mutex<LruCache<String, CachedTensor>>,
77    /// Per-request placement override. Set at the start of `generate()`,
78    /// cleared on exit. `None` preserves the existing VRAM-aware auto logic.
79    pending_placement: Option<mold_core::types::DevicePlacement>,
80    /// Per-request LoRA stack (effective: zero-scale entries already filtered).
81    /// Set at the start of `generate()`, cleared on exit. Read by
82    /// `load_transformer` / `reload_transformer_if_needed` to decide whether
83    /// to wrap the transformer's `VarBuilder` with a `Flux2LoraBackend`.
84    pending_loras: Vec<LoraWeight>,
85    shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
86}
87
88/// Resolve the effective LoRA list for a request. Mirrors the FLUX helper of
89/// the same shape: `loras` (plural) wins over `lora` (singular) when both are
90/// set, and zero-scale entries are filtered out so they don't trigger a
91/// transformer rebuild for nothing.
92pub(crate) fn effective_flux2_loras(req: &GenerateRequest) -> Vec<LoraWeight> {
93    /// Threshold below which a LoRA scale is treated as off (matches FLUX).
94    const ZERO_SCALE_EPS: f64 = 1e-8;
95
96    let raw: Vec<LoraWeight> = if let Some(plural) = &req.loras {
97        if !plural.is_empty() {
98            plural.clone()
99        } else {
100            req.lora.iter().cloned().collect()
101        }
102    } else {
103        req.lora.iter().cloned().collect()
104    };
105    raw.into_iter()
106        .filter(|w| {
107            let keep = w.scale.abs() > ZERO_SCALE_EPS;
108            if !keep {
109                tracing::debug!(
110                    path = w.path.as_str(),
111                    scale = w.scale,
112                    "dropping zero-scale Flux.2 LoRA"
113                );
114            }
115            keep
116        })
117        .collect()
118}
119
/// Result of deciding whether Flux.2 block-level transformer offload applies.
#[derive(Debug, PartialEq, Eq)]
enum Flux2OffloadDecision {
    /// Offload was not requested.
    Disabled,
    /// Offload was requested and this checkpoint/LoRA combination allows it.
    Selected,
    /// Offload was requested but cannot be honored; carries the reason.
    Unsupported(&'static str),
}

/// Decide how a block-level offload request should be handled.
///
/// * `forced_offload` — the caller asked for block-level offload.
/// * `is_gguf` — the transformer checkpoint is GGUF.
/// * `has_lora` — at least one effective LoRA is attached.
///
/// If offload was not requested, the result is `Disabled` regardless of the
/// other flags. Otherwise GGUF is rejected first, then LoRA. Only a non-GGUF,
/// LoRA-free request yields `Selected`.
fn flux2_offload_decision(
    forced_offload: bool,
    is_gguf: bool,
    has_lora: bool,
) -> Flux2OffloadDecision {
    match (forced_offload, is_gguf, has_lora) {
        (false, _, _) => Flux2OffloadDecision::Disabled,
        (true, true, _) => Flux2OffloadDecision::Unsupported(
            "Flux.2 block-level offload is only planned for BF16/FP transformers; \
             GGUF variants already use quantized transformer paths",
        ),
        (true, false, true) => Flux2OffloadDecision::Unsupported(
            "Flux.2 block-level offload with LoRA is not wired yet; \
             LoRA merge/bypass semantics need a dedicated offload design",
        ),
        (true, false, false) => Flux2OffloadDecision::Selected,
    }
}
149
150impl Flux2Engine {
151    /// Create a new Flux2Engine. Does not load models until `load()` is called.
152    pub fn new(
153        model_name: String,
154        paths: ModelPaths,
155        qwen3_variant: Option<String>,
156        load_strategy: LoadStrategy,
157        gpu_ordinal: usize,
158        offload: bool,
159        shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
160    ) -> Self {
161        Self {
162            base: EngineBase::new(model_name, paths, load_strategy, gpu_ordinal),
163            qwen3_variant,
164            offload,
165            prompt_cache: Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY)),
166            pending_placement: None,
167            pending_loras: Vec::new(),
168            shared_pool,
169        }
170    }
171
172    /// Construct a Flux.2 engine from a Civitai / ComfyUI single-file
173    /// safetensors checkpoint (BFL-native naming, every key prefixed
174    /// `model.diffusion_model.`).
175    ///
176    /// The transformer is the single-file checkpoint itself; the VAE,
177    /// Qwen3 text encoder, and tokenizer arrive via companion paths
178    /// resolved by the catalog bridge before the engine is constructed.
179    /// The header is not peeked here — `load_transformer` re-detects the
180    /// format at load time so a per-engine error surfaces in the same
181    /// place as every other transformer load failure.
182    #[allow(clippy::too_many_arguments)]
183    pub fn from_single_file(
184        model_name: String,
185        transformer_path: PathBuf,
186        vae_path: PathBuf,
187        text_encoder_files: Vec<PathBuf>,
188        text_tokenizer: Option<PathBuf>,
189        qwen3_variant: Option<String>,
190        load_strategy: LoadStrategy,
191        gpu_ordinal: usize,
192        offload: bool,
193        shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
194    ) -> Result<Self> {
195        if !transformer_path.exists() {
196            bail!(
197                "single-file Flux.2 checkpoint not found: {}",
198                transformer_path.display()
199            );
200        }
201
202        let paths = ModelPaths {
203            transformer: transformer_path,
204            transformer_shards: Vec::new(),
205            vae: vae_path,
206            spatial_upscaler: None,
207            temporal_upscaler: None,
208            distilled_lora: None,
209            t5_encoder: None,
210            clip_encoder: None,
211            t5_tokenizer: None,
212            clip_tokenizer: None,
213            clip_encoder_2: None,
214            clip_tokenizer_2: None,
215            text_encoder_files,
216            text_tokenizer,
217            decoder: None,
218        };
219
220        Ok(Self {
221            base: EngineBase::new(model_name, paths, load_strategy, gpu_ordinal),
222            qwen3_variant,
223            offload,
224            prompt_cache: Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY)),
225            pending_placement: None,
226            pending_loras: Vec::new(),
227            shared_pool,
228        })
229    }
230
231    /// Select the appropriate transformer config. Header-peeks the
232    /// checkpoint when it's a single-file `.safetensors` to determine
233    /// hidden_size (3072 → Klein-4B, 4096 → Klein-9B). Falls back to
234    /// the model-name heuristic for sharded HF diffusers layouts and
235    /// when header-peek can't find an `img_in.weight` marker (e.g. some
236    /// community FP8 conversions). This is necessary for opaque names
237    /// like `cv:2759597` whose mapping to a Klein variant is only
238    /// recoverable from the file itself.
239    fn resolve_config(&self) -> Flux2Config {
240        if let Some(cfg) = self.detect_config_from_checkpoint() {
241            return cfg;
242        }
243        if self.base.model_name.to_lowercase().contains("9b") {
244            Flux2Config::klein_9b()
245        } else {
246            Flux2Config::klein()
247        }
248    }
249
250    /// Header-peek the transformer file (if it's a single `.safetensors`)
251    /// and pick the config matching its `hidden_size`. Returns `None` for
252    /// sharded loads or when no `img_in.weight` marker is present.
253    fn detect_config_from_checkpoint(&self) -> Option<Flux2Config> {
254        if !self.base.paths.transformer_shards.is_empty() {
255            return None;
256        }
257        let path = &self.base.paths.transformer;
258        let is_safetensors = path
259            .extension()
260            .and_then(|e| e.to_str())
261            .is_some_and(|e| e.eq_ignore_ascii_case("safetensors"));
262        if !is_safetensors {
263            return None;
264        }
265        match super::single_file::detect_hidden_size(path) {
266            Ok(Some(4096)) => Some(Flux2Config::klein_9b()),
267            Ok(Some(3072)) => Some(Flux2Config::klein()),
268            // Anything else: unknown variant, defer to name heuristic.
269            _ => None,
270        }
271    }
272
273    /// Whether this is a Klein-9B model (uses Qwen3-8B text encoder).
274    /// Mirrors `resolve_config` — peek the checkpoint first, fall back
275    /// to the model-name heuristic.
276    fn is_9b(&self) -> bool {
277        if let Some(cfg) = self.detect_config_from_checkpoint() {
278            return cfg.hidden_size == 4096;
279        }
280        self.base.model_name.to_lowercase().contains("9b")
281    }
282
283    /// Return the Qwen3 encoder size enum for this model.
284    fn qwen3_size(&self) -> crate::encoders::variant_resolution::Qwen3Size {
285        if self.is_9b() {
286            crate::encoders::variant_resolution::Qwen3Size::B8
287        } else {
288            crate::encoders::variant_resolution::Qwen3Size::B4
289        }
290    }
291
292    /// Return the BF16 config for the Qwen3 encoder used by this model.
293    fn qwen3_bf16_config(&self) -> encoders::qwen3_bf16::Qwen3BF16Config {
294        if self.is_9b() {
295            encoders::qwen3_bf16::Qwen3BF16Config::qwen3_8b()
296        } else {
297            encoders::qwen3_bf16::Qwen3BF16Config::qwen3_4b()
298        }
299    }
300
301    fn load_text_tokenizer(&self, tokenizer_path: &Path) -> Result<Arc<Tokenizer>> {
302        if let Some(shared_pool) = &self.shared_pool {
303            return shared_pool.lock().unwrap().load_tokenizer(tokenizer_path);
304        }
305        Tokenizer::from_file(tokenizer_path)
306            .map(Arc::new)
307            .map_err(|e| anyhow::anyhow!("failed to load Qwen3 tokenizer: {e}"))
308    }
309
310    fn load_vae_cpu_tensors(&self) -> Result<Option<Arc<HashMap<String, Tensor>>>> {
311        let Some(shared_pool) = &self.shared_pool else {
312            return Ok(None);
313        };
314        shared_pool
315            .lock()
316            .unwrap()
317            .load_safetensors_cpu_tensors(std::slice::from_ref(&self.base.paths.vae))
318    }
319
320    fn load_vae_var_builder<'a>(
321        &self,
322        dtype: DType,
323        device: &Device,
324        component: &str,
325    ) -> Result<VarBuilder<'a>> {
326        if let Some(tensors) = self.load_vae_cpu_tensors()? {
327            return Ok(crate::encoders::park::varbuilder_from_parked(
328                tensors.as_ref(),
329                dtype,
330                device,
331            ));
332        }
333
334        crate::weight_loader::load_safetensors_with_progress(
335            std::slice::from_ref(&self.base.paths.vae),
336            dtype,
337            device,
338            component,
339            &self.base.progress,
340        )
341    }
342
343    fn img2img_source_normalize_range() -> crate::img_utils::NormalizeRange {
344        crate::img_utils::NormalizeRange::MinusOneToOne
345    }
346
347    #[cfg(test)]
348    fn sequential_img2img_preencodes_source() -> bool {
349        true
350    }
351
352    fn uses_sequential_generate_path(&self) -> bool {
353        self.base.load_strategy == LoadStrategy::Sequential
354            || self.offload
355            || !self.pending_loras.is_empty()
356    }
357
358    fn load_sequential_vae(
359        &self,
360        device: &Device,
361        gpu_dtype: DType,
362    ) -> Result<(Flux2AutoEncoder, DType)> {
363        let vae_ref =
364            effective_device_ref(self.pending_placement.as_ref(), |adv| Some(adv.vae), false);
365        let vae_device = crate::device::resolve_device(Some(vae_ref), || Ok(device.clone()))?;
366        self.base.progress.stage_start("Loading VAE (GPU)");
367        let vae_stage = Instant::now();
368        let vae_cfg = Flux2VaeConfig::klein();
369        // Sequential path resolves MOLD_VAE_DTYPE per request — env changes
370        // take effect on the next generate() without an engine reload.
371        let vae_dtype = crate::device::resolve_vae_dtype(gpu_dtype);
372        let vae_vb = self.load_vae_var_builder(vae_dtype, &vae_device, "VAE")?;
373        let vae = Flux2AutoEncoder::new(&vae_cfg, vae_vb)?;
374        self.base
375            .progress
376            .stage_done("Loading VAE (GPU)", vae_stage.elapsed());
377        Ok((vae, vae_dtype))
378    }
379
380    /// Validate that all required paths exist.
381    fn validate_paths(&self) -> Result<std::path::PathBuf> {
382        let text_tokenizer_path = self
383            .base
384            .paths
385            .text_tokenizer
386            .as_ref()
387            .ok_or_else(|| anyhow::anyhow!("text tokenizer path required for Flux.2 models"))?;
388        if !text_tokenizer_path.exists() {
389            bail!(
390                "text tokenizer file not found: {}",
391                text_tokenizer_path.display()
392            );
393        }
394
395        let encoder_paths = self.text_encoder_paths();
396        if encoder_paths.is_empty() {
397            bail!("text encoder paths required for Flux.2 models");
398        }
399        for path in &encoder_paths {
400            if !path.exists() {
401                bail!("text encoder file not found: {}", path.display());
402            }
403        }
404
405        if !self.base.paths.transformer.exists() {
406            bail!(
407                "transformer file not found: {}",
408                self.base.paths.transformer.display()
409            );
410        }
411        if !self.base.paths.vae.exists() {
412            bail!("VAE file not found: {}", self.base.paths.vae.display());
413        }
414
415        Ok(text_tokenizer_path.clone())
416    }
417
418    /// Check if the transformer file is a GGUF (quantized) file.
419    fn is_gguf_transformer(&self) -> bool {
420        self.base
421            .paths
422            .transformer
423            .extension()
424            .and_then(|e| e.to_str())
425            .map(|e| e.eq_ignore_ascii_case("gguf"))
426            .unwrap_or(false)
427    }
428
429    /// Load the transformer from either GGUF or BF16 safetensors.
430    ///
431    /// When `self.pending_loras` is non-empty, every BF16 / GGUF branch wraps
432    /// the underlying tensor source with a Flux.2 LoRA backend so the
433    /// constructed transformer carries `W' = W + scale·B@A` for every
434    /// LoRA-targeted layer (additive across multiple adapters). See
435    /// `super::lora` for the key-mapping rules.
436    fn load_transformer(
437        &self,
438        cfg: &Flux2Config,
439        gpu_dtype: DType,
440        device: &Device,
441    ) -> Result<(Flux2TransformerWrapper, &'static str)> {
442        let has_lora = !self.pending_loras.is_empty();
443        if self.is_gguf_transformer() {
444            if has_lora {
445                // Dequant→merge→requant on every LoRA-affected GGUF tensor.
446                // Non-LoRA tensors stay quantised, untouched.
447                let adapters =
448                    super::lora::load_lora_adapters(&self.pending_loras, &self.base.progress)?;
449                let specs: Vec<super::lora::Flux2LoraSpec<'_>> = adapters
450                    .iter()
451                    .zip(self.pending_loras.iter())
452                    .map(|(adapter, w)| super::lora::Flux2LoraSpec {
453                        adapter: adapter.as_ref(),
454                        scale: w.scale,
455                        path_hash: super::lora::lora_path_hash(&w.path),
456                    })
457                    .collect();
458                let gguf_vb = super::lora::gguf_lora_var_builder_flux2(
459                    &self.base.paths.transformer,
460                    &specs,
461                    device,
462                    &self.base.progress,
463                    None,
464                )?;
465                return Ok((
466                    Flux2TransformerWrapper::Quantized(
467                        super::quantized_transformer::QuantizedFlux2Transformer::new(
468                            cfg, gguf_vb, device,
469                        )?,
470                    ),
471                    "Loading Flux.2 transformer (GPU, GGUF + LoRA)",
472                ));
473            }
474            // Weights stay quantized in VRAM via QMatMul — no dequantization at load
475            // time. A Q4 Klein-9B uses ~6GB VRAM instead of ~18GB with full dequant.
476            let gguf_vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
477                &self.base.paths.transformer,
478                device,
479            )?;
480            Ok((
481                Flux2TransformerWrapper::Quantized(
482                    super::quantized_transformer::QuantizedFlux2Transformer::new(
483                        cfg, gguf_vb, device,
484                    )?,
485                ),
486                "Loading Flux.2 transformer (GPU, GGUF)",
487            ))
488        } else if self.is_bfl_native_single_file() {
489            // Civitai / ComfyUI single-file checkpoints carry BFL-native
490            // tensor names (`model.diffusion_model.*`); the diffusers
491            // `Flux2Transformer::new` consumer is wrapped over a
492            // `SingleFileBackend` that translates those keys on the fly.
493            tracing::info!(
494                path = %self.base.paths.transformer.display(),
495                "loading Flux.2 transformer from BFL-native single-file checkpoint"
496            );
497            let backend =
498                crate::loader::single_file_backend::SingleFileBackend::from_flux2_singlefile(
499                    &self.base.paths.transformer,
500                    cfg,
501                )?;
502            let backend: Box<dyn candle_nn::var_builder::SimpleBackend> = Box::new(backend);
503            if self.offload && !has_lora {
504                let flux_vb = candle_nn::VarBuilder::from_backend(backend, gpu_dtype, Device::Cpu);
505                return Ok((
506                    Flux2TransformerWrapper::Offloaded(
507                        super::transformer::OffloadedFlux2Transformer::new(cfg, flux_vb, device)?,
508                    ),
509                    "Loading Flux.2 transformer (offload, BF16, single-file remap)",
510                ));
511            }
512            let backend = if has_lora {
513                let adapters =
514                    super::lora::load_lora_adapters(&self.pending_loras, &self.base.progress)?;
515                let specs: Vec<super::lora::Flux2LoraSpec<'_>> = adapters
516                    .iter()
517                    .zip(self.pending_loras.iter())
518                    .map(|(adapter, w)| super::lora::Flux2LoraSpec {
519                        adapter: adapter.as_ref(),
520                        scale: w.scale,
521                        path_hash: super::lora::lora_path_hash(&w.path),
522                    })
523                    .collect();
524                super::lora::wrap_backend_with_lora(
525                    backend,
526                    &specs,
527                    super::lora::Flux2KeySpace::Diffusers,
528                    &self.base.progress,
529                    None,
530                )?
531            } else {
532                backend
533            };
534            let flux_vb = candle_nn::VarBuilder::from_backend(backend, gpu_dtype, device.clone());
535            let label = if has_lora {
536                "Loading Flux.2 transformer (GPU, BF16, single-file remap + LoRA)"
537            } else {
538                "Loading Flux.2 transformer (GPU, BF16, single-file remap)"
539            };
540            Ok((
541                Flux2TransformerWrapper::BF16(super::transformer::Flux2Transformer::new(
542                    cfg, flux_vb,
543                )?),
544                label,
545            ))
546        } else {
547            let xformer_paths = if !self.base.paths.transformer_shards.is_empty() {
548                self.base.paths.transformer_shards.clone()
549            } else {
550                vec![self.base.paths.transformer.clone()]
551            };
552            let (flux_vb, offloaded_label) = if has_lora {
553                // Build our own mmap-backed SimpleBackend so we can wrap with
554                // `Flux2LoraBackend`. The progress reporting drops to a single
555                // info line — the legacy progress bar is keyed to candle's
556                // internal mmap path.
557                use candle_core::safetensors::MmapedSafetensors;
558                let path_refs: Vec<&std::path::Path> =
559                    xformer_paths.iter().map(|p| p.as_path()).collect();
560                let st = unsafe { MmapedSafetensors::multi(&path_refs)? };
561                struct MmapBackend {
562                    st: MmapedSafetensors,
563                }
564                impl candle_nn::var_builder::SimpleBackend for MmapBackend {
565                    fn get(
566                        &self,
567                        _s: candle_core::Shape,
568                        name: &str,
569                        _h: candle_nn::Init,
570                        dtype: DType,
571                        dev: &Device,
572                    ) -> candle_core::Result<Tensor> {
573                        let t = self.st.load(name, dev)?;
574                        if t.dtype() != dtype {
575                            t.to_dtype(dtype)
576                        } else {
577                            Ok(t)
578                        }
579                    }
580                    fn get_unchecked(
581                        &self,
582                        name: &str,
583                        dtype: DType,
584                        dev: &Device,
585                    ) -> candle_core::Result<Tensor> {
586                        let t = self.st.load(name, dev)?;
587                        if t.dtype() != dtype {
588                            t.to_dtype(dtype)
589                        } else {
590                            Ok(t)
591                        }
592                    }
593                    fn contains_tensor(&self, name: &str) -> bool {
594                        self.st.get(name).is_ok()
595                    }
596                }
597                let inner: Box<dyn candle_nn::var_builder::SimpleBackend> =
598                    Box::new(MmapBackend { st });
599                let adapters =
600                    super::lora::load_lora_adapters(&self.pending_loras, &self.base.progress)?;
601                let specs: Vec<super::lora::Flux2LoraSpec<'_>> = adapters
602                    .iter()
603                    .zip(self.pending_loras.iter())
604                    .map(|(adapter, w)| super::lora::Flux2LoraSpec {
605                        adapter: adapter.as_ref(),
606                        scale: w.scale,
607                        path_hash: super::lora::lora_path_hash(&w.path),
608                    })
609                    .collect();
610                let wrapped = super::lora::wrap_backend_with_lora(
611                    inner,
612                    &specs,
613                    super::lora::Flux2KeySpace::Diffusers,
614                    &self.base.progress,
615                    None,
616                )?;
617                (
618                    candle_nn::VarBuilder::from_backend(wrapped, gpu_dtype, device.clone()),
619                    None,
620                )
621            } else if self.offload {
622                (
623                    crate::weight_loader::load_safetensors_with_progress(
624                        &xformer_paths,
625                        gpu_dtype,
626                        &Device::Cpu,
627                        "Flux.2 transformer (offload blocks)",
628                        &self.base.progress,
629                    )?,
630                    Some("Loading Flux.2 transformer (offload, BF16)"),
631                )
632            } else {
633                (
634                    crate::weight_loader::load_safetensors_with_progress(
635                        &xformer_paths,
636                        gpu_dtype,
637                        device,
638                        "Flux.2 transformer",
639                        &self.base.progress,
640                    )?,
641                    None,
642                )
643            };
644            if let Some(label) = offloaded_label {
645                return Ok((
646                    Flux2TransformerWrapper::Offloaded(
647                        super::transformer::OffloadedFlux2Transformer::new(cfg, flux_vb, device)?,
648                    ),
649                    label,
650                ));
651            }
652            let label = if has_lora {
653                "Loading Flux.2 transformer (GPU, BF16 + LoRA)"
654            } else {
655                "Loading Flux.2 transformer (GPU, BF16)"
656            };
657            Ok((
658                Flux2TransformerWrapper::BF16(super::transformer::Flux2Transformer::new(
659                    cfg, flux_vb,
660                )?),
661                label,
662            ))
663        }
664    }
665
666    /// `true` when the transformer is a single `.safetensors` file whose
667    /// tensor keys are BFL-native (`model.diffusion_model.*`). Returns
668    /// `true` for `BflNative`, `BflNativeRoot`, and `Nvfp4` — all route
669    /// through `SingleFileBackend` (the NVFP4 path dequantises FP4×FP8
670    /// blocks to FP8-E4M3 on lookup; the BFL variants pass tensors
671    /// through directly). Sharded loads (HF diffusers layout) and any
672    /// non-safetensors path skip this detection. Header-peeks the file
673    /// once per load — a few KB read.
674    fn is_bfl_native_single_file(&self) -> bool {
675        if !self.base.paths.transformer_shards.is_empty() {
676            return false;
677        }
678        let path = &self.base.paths.transformer;
679        let is_safetensors = path
680            .extension()
681            .and_then(|e| e.to_str())
682            .is_some_and(|e| e.eq_ignore_ascii_case("safetensors"));
683        if !is_safetensors {
684            return false;
685        }
686        matches!(
687            super::single_file::detect_format(path),
688            Ok(super::single_file::Flux2SingleFileFormat::BflNative)
689                | Ok(super::single_file::Flux2SingleFileFormat::BflNativeRoot)
690                | Ok(super::single_file::Flux2SingleFileFormat::Nvfp4)
691        )
692    }
693
694    /// Reload transformer using `&mut self` — called before the main `loaded` borrow
695    /// to avoid borrow conflicts.
696    fn reload_transformer_if_needed(&mut self) -> Result<()> {
697        let needs_reload = self
698            .base
699            .loaded
700            .as_ref()
701            .map(|l| l.transformer.is_none())
702            .unwrap_or(false);
703
704        if needs_reload {
705            let cfg = self.resolve_config();
706            self.base
707                .progress
708                .stage_start("Reloading Flux.2 transformer");
709            let reload_start = Instant::now();
710            let (transformer, _label) = self.load_transformer(
711                &cfg,
712                self.base.loaded.as_ref().unwrap().dtype,
713                &self.base.loaded.as_ref().unwrap().device.clone(),
714            )?;
715            self.base.loaded.as_mut().unwrap().transformer = Some(transformer);
716            self.base
717                .progress
718                .stage_done("Reloading Flux.2 transformer", reload_start.elapsed());
719        }
720        Ok(())
721    }
722
723    fn should_delay_transformer_reload_for_prompt_encode(
724        load_strategy: LoadStrategy,
725        transformer_loaded: bool,
726    ) -> bool {
727        load_strategy == LoadStrategy::Eager && !transformer_loaded
728    }
729
730    /// Get text encoder file paths (shards or single file).
731    fn text_encoder_paths(&self) -> Vec<std::path::PathBuf> {
732        if !self.base.paths.text_encoder_files.is_empty() {
733            self.base.paths.text_encoder_files.clone()
734        } else {
735            // Fallback: t5_encoder field is reused as the generic text encoder path
736            self.base
737                .paths
738                .t5_encoder
739                .as_ref()
740                .map(|p| vec![p.clone()])
741                .unwrap_or_default()
742        }
743    }
744
745    /// Encode a prompt with the Qwen3 text encoder, extracting hidden states from
746    /// layers 9, 18, 27 and stacking them to produce the context embedding.
747    ///
748    /// - Klein-4B: Qwen3-4B (hidden=2560) → stacked dim = 2560 * 3 = 7680
749    /// - Klein-9B: Qwen3-8B (hidden=4096) → stacked dim = 4096 * 3 = 12288
750    ///
751    /// Both use the same 36-layer Qwen3 architecture. Layers 9, 18, 27 correspond
752    /// to roughly 1/4, 1/2, 3/4 depth.
753    const QWEN3_HIDDEN_LAYERS: [usize; 3] = [9, 18, 27];
754
755    fn encode_and_stack(
756        encoder: &mut encoders::qwen3::Qwen3Encoder,
757        prompt: &str,
758        target_device: &Device,
759        target_dtype: DType,
760    ) -> Result<Tensor> {
761        // Extract hidden states from layers 9, 18, 27 and stack to (B, seq, 7680)
762        let (stacked, _token_count) = encoder.encode_with_layers(
763            prompt,
764            target_device,
765            target_dtype,
766            &Self::QWEN3_HIDDEN_LAYERS,
767        )?;
768        Ok(stacked)
769    }
770
771    fn encode_prompt_cached(
772        progress: &ProgressReporter,
773        prompt_cache: &Mutex<LruCache<String, CachedTensor>>,
774        encoder: &mut encoders::qwen3::Qwen3Encoder,
775        prompt: &str,
776        target_device: &Device,
777        target_dtype: DType,
778    ) -> Result<Tensor> {
779        let cache_key = prompt_text_key(prompt);
780        let (txt_emb, cache_hit) = get_or_insert_cached_tensor(
781            prompt_cache,
782            cache_key,
783            target_device,
784            target_dtype,
785            || {
786                progress.stage_start("Encoding prompt (Qwen3)");
787                let encode_start = Instant::now();
788                let txt_emb = Self::encode_and_stack(encoder, prompt, target_device, target_dtype)?;
789                progress.stage_done("Encoding prompt (Qwen3)", encode_start.elapsed());
790                Ok(txt_emb)
791            },
792        )?;
793        if cache_hit {
794            progress.cache_hit("prompt conditioning");
795        }
796        Ok(txt_emb)
797    }
798
    /// Load all model components (Eager mode).
    ///
    /// On error, `self.base.loaded` remains `None` — all components are assembled into
    /// local variables and only stored in `self.base.loaded` on success, so partial loads
    /// cannot leave the engine in an inconsistent state.
    ///
    /// GGUF variants keep weights quantized in VRAM via QMatMul (~6GB for Q4 9B),
    /// so both Klein-4B and Klein-9B fit comfortably in eager mode on 24GB GPUs.
    ///
    /// Components load in the order transformer → VAE → Qwen3 encoder. The encoder
    /// goes last so its variant can be selected against the free VRAM measured
    /// after the transformer and VAE are already resident.
    ///
    /// Returns `Ok(())` immediately when components are already loaded, or when the
    /// engine uses `LoadStrategy::Sequential` (loading is deferred to
    /// `generate_sequential`).
    ///
    /// # Errors
    ///
    /// Propagates failures from path validation, device resolution, and loading of
    /// the transformer, VAE, tokenizer, or Qwen3 encoder.
    pub fn load(&mut self) -> Result<()> {
        if self.base.loaded.is_some() {
            return Ok(());
        }

        // Sequential mode defers loading to generate_sequential()
        if self.base.load_strategy == LoadStrategy::Sequential {
            return Ok(());
        }

        tracing::info!(model = %self.base.model_name, "loading Flux.2 Klein model components...");

        // The returned tokenizer path is used below when loading the Qwen3 encoder.
        let text_tokenizer_path = self.validate_paths()?;

        let cpu = Device::Cpu;
        // Transformer device: taken from the pending placement when it provides one;
        // the closure (create the device for `gpu_ordinal`) is the fallback.
        let transformer_ref = effective_device_ref(
            self.pending_placement.as_ref(),
            |adv| Some(adv.transformer),
            false,
        );
        let device = crate::device::resolve_device(Some(transformer_ref), || {
            crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)
        })?;
        let gpu_dtype = crate::engine::gpu_dtype(&device);

        tracing::info!("GPU device: {:?}, GPU dtype: {:?}", device, gpu_dtype);

        // --- Load transformer on GPU first ---
        let flux2_cfg = self.resolve_config();
        let xformer_stage = Instant::now();
        let (transformer, xformer_label) = self.load_transformer(&flux2_cfg, gpu_dtype, &device)?;
        self.base
            .progress
            .stage_done(xformer_label, xformer_stage.elapsed());

        // --- Load VAE on GPU ---
        // The VAE device falls back to the transformer device (presumably when no
        // explicit VAE placement is set). NOTE(review): the "GPU" wording in the
        // stage label and logs below is fixed even if the placement resolves to CPU.
        let vae_ref =
            effective_device_ref(self.pending_placement.as_ref(), |adv| Some(adv.vae), false);
        let vae_device = crate::device::resolve_device(Some(vae_ref), || Ok(device.clone()))?;
        self.base.progress.stage_start("Loading VAE (GPU)");
        let vae_stage = Instant::now();
        tracing::info!(path = %self.base.paths.vae.display(), "loading VAE on GPU...");
        let vae_cfg = Flux2VaeConfig::klein();
        // Resolve VAE precision once at load — see LoadedFlux2::vae_dtype.
        let vae_dtype = crate::device::resolve_vae_dtype(gpu_dtype);
        let vae_vb = self.load_vae_var_builder(vae_dtype, &vae_device, "VAE")?;
        let vae = Flux2AutoEncoder::new(&vae_cfg, vae_vb)?;
        self.base
            .progress
            .stage_done("Loading VAE (GPU)", vae_stage.elapsed());
        tracing::info!("VAE loaded on GPU");

        // --- Resolve and load Qwen3 text encoder ---
        // Log the raw reading (matches `nvidia-smi`); budget the variant
        // selection against the reserve-adjusted value.
        // An unavailable reading falls back to 0 bytes.
        let free_raw = free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
        let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
        if free_raw > 0 {
            self.base.progress.info(&format!(
                "Free VRAM after transformer+VAE: {}",
                fmt_gb(free_raw)
            ));
        }

        self.base.progress.stage_start("Selecting Qwen3 encoder");
        let resolve_start = Instant::now();
        let qwen3_size = self.qwen3_size();
        // The resolver yields: weight paths, whether they are GGUF, whether the
        // encoder should go on the GPU, and a display label (as used below).
        let (encoder_paths, is_gguf, on_gpu, device_label) = {
            let bf16_paths = self.text_encoder_paths();
            // BF16 only counts as available when every listed file exists on disk.
            let have_bf16 = !bf16_paths.is_empty() && bf16_paths.iter().all(|p| p.exists());
            crate::encoders::variant_resolution::resolve_qwen3_variant(
                &self.base.progress,
                self.qwen3_variant.as_deref(),
                &device,
                free,
                &bf16_paths,
                have_bf16,
                true,
                qwen3_size,
            )?
        };
        self.base
            .progress
            .stage_done("Selecting Qwen3 encoder", resolve_start.elapsed());

        // The resolver's GPU/CPU choice is only the fallback device; a Qwen3
        // placement (if provided) decides the final device, so `on_gpu` is
        // recomputed from the device actually used.
        let qwen3_ref = effective_device_ref(self.pending_placement.as_ref(), |adv| adv.qwen, true);
        let auto_enc_device = if on_gpu { device.clone() } else { cpu.clone() };
        let enc_device_owned =
            crate::device::resolve_device(Some(qwen3_ref), || Ok(auto_enc_device.clone()))?;
        let enc_device = &enc_device_owned;
        let on_gpu = !enc_device.is_cpu();
        // GPU encoders run in the transformer dtype; CPU encoders run in F32.
        let enc_dtype = if on_gpu { gpu_dtype } else { DType::F32 };
        let bf16_cfg = self.qwen3_bf16_config();

        // NOTE(review): `device_label` comes from the resolver, so it may not
        // reflect a placement override applied above.
        let enc_stage_label = format!("Loading Qwen3 encoder ({device_label})");
        self.base.progress.stage_start(&enc_stage_label);
        let enc_stage = Instant::now();
        let text_tokenizer = self.load_text_tokenizer(&text_tokenizer_path)?;

        // GGUF loads only the first resolved path and takes no dtype argument;
        // BF16 loads every path in `enc_dtype`.
        // NOTE(review): `encoder_paths[0]` panics if the resolver ever returns an
        // empty list for a GGUF variant — presumably it never does; confirm.
        let text_encoder = if is_gguf {
            encoders::qwen3::Qwen3Encoder::load_gguf_with_tokenizer(
                &encoder_paths[0],
                &text_tokenizer_path,
                Some(text_tokenizer),
                enc_device,
                &bf16_cfg,
            )?
        } else {
            encoders::qwen3::Qwen3Encoder::load_bf16_with_tokenizer(
                &encoder_paths,
                &text_tokenizer_path,
                Some(text_tokenizer),
                enc_device,
                enc_dtype,
                &bf16_cfg,
                &self.base.progress,
            )?
        };
        self.base
            .progress
            .stage_done(&enc_stage_label, enc_stage.elapsed());
        tracing::info!(device = %device_label, "Qwen3 encoder loaded");

        // Commit everything at once; any earlier `?` leaves `loaded` as `None`.
        self.base.loaded = Some(LoadedFlux2 {
            transformer: Some(transformer),
            text_encoder,
            vae,
            device,
            dtype: gpu_dtype,
            vae_dtype,
        });

        tracing::info!(model = %self.base.model_name, "all Flux.2 model components loaded successfully");
        Ok(())
    }
942
943    /// Generate an image using sequential loading strategy (load-use-drop).
944    fn generate_sequential(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
945        let text_tokenizer_path = self.validate_paths()?;
946        let is_gguf = self.is_gguf_transformer();
947
948        match flux2_offload_decision(self.offload, is_gguf, !self.pending_loras.is_empty()) {
949            Flux2OffloadDecision::Disabled => {}
950            Flux2OffloadDecision::Unsupported(reason) => bail!("{reason}"),
951            Flux2OffloadDecision::Selected => {}
952        }
953
954        if let Some(warning) = check_memory_budget(&self.base.paths, LoadStrategy::Sequential) {
955            self.base.progress.info(&warning);
956        }
957
958        let transformer_ref = effective_device_ref(
959            self.pending_placement.as_ref(),
960            |adv| Some(adv.transformer),
961            false,
962        );
963        let device = crate::device::resolve_device(Some(transformer_ref), || {
964            crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)
965        })?;
966        let gpu_dtype = crate::engine::gpu_dtype(&device);
967
968        let start = Instant::now();
969        let seed = req.seed.unwrap_or_else(rand_seed);
970
971        let width = req.width as usize;
972        let height = req.height as usize;
973
974        tracing::info!(
975            prompt = %req.prompt,
976            seed, width, height,
977            steps = req.steps,
978            "starting sequential Flux.2 generation"
979        );
980
981        self.base
982            .progress
983            .info("Using sequential loading (load-use-drop) to minimize peak memory");
984
985        // --- Phase 1: Qwen3 text encoding ---
986        // Check prompt cache first — skip encoder load entirely on cache hit.
987        // This saves ~1-5s per batch image (encoder load + VRAM allocation).
988        let cache_key = prompt_text_key(&req.prompt);
989        let txt_emb = if let Some(tensor) =
990            restore_cached_tensor(&self.prompt_cache, &cache_key, &device, gpu_dtype)?
991        {
992            self.base.progress.cache_hit("prompt conditioning");
993            tensor
994        } else {
995            // Reserve-adjusted reading drives the Qwen3 variant selection.
996            let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
997            self.base.progress.stage_start("Selecting Qwen3 encoder");
998            let resolve_start = Instant::now();
999            let qwen3_size = self.qwen3_size();
1000            let (encoder_paths, is_gguf, on_gpu, device_label) = {
1001                let bf16_paths = self.text_encoder_paths();
1002                let have_bf16 = !bf16_paths.is_empty() && bf16_paths.iter().all(|p| p.exists());
1003                crate::encoders::variant_resolution::resolve_qwen3_variant(
1004                    &self.base.progress,
1005                    self.qwen3_variant.as_deref(),
1006                    &device,
1007                    free,
1008                    &bf16_paths,
1009                    have_bf16,
1010                    true,
1011                    qwen3_size,
1012                )?
1013            };
1014            self.base
1015                .progress
1016                .stage_done("Selecting Qwen3 encoder", resolve_start.elapsed());
1017
1018            let qwen3_ref =
1019                effective_device_ref(self.pending_placement.as_ref(), |adv| adv.qwen, true);
1020            let auto_enc_device = if on_gpu { device.clone() } else { Device::Cpu };
1021            let enc_device_owned =
1022                crate::device::resolve_device(Some(qwen3_ref), || Ok(auto_enc_device.clone()))?;
1023            let enc_device = &enc_device_owned;
1024            let on_gpu = !enc_device.is_cpu();
1025            let enc_dtype = if on_gpu { gpu_dtype } else { DType::F32 };
1026            let bf16_cfg = self.qwen3_bf16_config();
1027
1028            // Pre-flight memory check
1029            let enc_size: u64 = encoder_paths
1030                .iter()
1031                .filter_map(|p| std::fs::metadata(p).ok().map(|m| m.len()))
1032                .sum();
1033            let enc_activation_budget = crate::device::activation_bytes(
1034                req.width,
1035                req.height,
1036                1,
1037                crate::device::dtype_bytes(enc_dtype),
1038                crate::device::ActivationFamily::SmallTransformer,
1039            );
1040            preflight_memory_check("Qwen3 encoder", enc_size, enc_activation_budget)?;
1041            if let Some(status) = memory_status_string() {
1042                self.base.progress.info(&status);
1043            }
1044
1045            let enc_stage_label = format!("Loading Qwen3 encoder ({device_label})");
1046            self.base.progress.stage_start(&enc_stage_label);
1047            let enc_stage = Instant::now();
1048            let text_tokenizer = self.load_text_tokenizer(&text_tokenizer_path)?;
1049
1050            let mut text_encoder = if is_gguf {
1051                encoders::qwen3::Qwen3Encoder::load_gguf_with_tokenizer(
1052                    &encoder_paths[0],
1053                    &text_tokenizer_path,
1054                    Some(text_tokenizer),
1055                    enc_device,
1056                    &bf16_cfg,
1057                )?
1058            } else {
1059                encoders::qwen3::Qwen3Encoder::load_bf16_with_tokenizer(
1060                    &encoder_paths,
1061                    &text_tokenizer_path,
1062                    Some(text_tokenizer),
1063                    enc_device,
1064                    enc_dtype,
1065                    &bf16_cfg,
1066                    &self.base.progress,
1067                )?
1068            };
1069            self.base
1070                .progress
1071                .stage_done(&enc_stage_label, enc_stage.elapsed());
1072
1073            let txt_emb = Self::encode_prompt_cached(
1074                &self.base.progress,
1075                &self.prompt_cache,
1076                &mut text_encoder,
1077                &req.prompt,
1078                &device,
1079                gpu_dtype,
1080            )?;
1081
1082            // Drop text encoder to free memory
1083            drop(text_encoder);
1084            self.base.progress.info("Freed Qwen3 encoder");
1085            tracing::info!("Qwen3 encoder dropped (sequential mode)");
1086
1087            txt_emb
1088        };
1089
1090        let latent_h = height.div_ceil(8);
1091        let latent_w = width.div_ceil(8);
1092
1093        // Pre-compute timestep schedule (needed before mixing for img2img)
1094        let image_seq_len = (height / 16) * (width / 16);
1095        let mut timesteps = sampling::get_schedule(req.steps as usize, image_seq_len);
1096
1097        if req.source_image.is_some() {
1098            let (trimmed, start_index) =
1099                crate::img2img::trim_schedule_tail(&timesteps, req.steps as usize, req.strength);
1100            timesteps = trimmed;
1101            tracing::info!(
1102                strength = req.strength,
1103                start_index,
1104                start_timestep = timesteps[0],
1105                schedule = ?timesteps,
1106                remaining_steps = timesteps.len().saturating_sub(1),
1107                "img2img: truncated schedule from strength"
1108            );
1109        }
1110
1111        // Generate noise / encode source image for img2img. Source-image
1112        // requests pre-encode in a VAE-only phase, then drop VAE before the
1113        // transformer load. Klein-9B BF16 cannot keep transformer+VAE
1114        // co-resident on 24 GB cards.
1115        let (img, inpaint_ctx) = if let Some(ref source_bytes) = req.source_image {
1116            let start_t = timesteps[0];
1117            let (vae, _vae_dtype) = self.load_sequential_vae(&device, gpu_dtype)?;
1118
1119            self.base
1120                .progress
1121                .stage_start("Encoding source image (VAE)");
1122            let encode_start = Instant::now();
1123            let source_tensor = crate::img_utils::decode_source_image(
1124                source_bytes,
1125                req.width,
1126                req.height,
1127                Self::img2img_source_normalize_range(),
1128                &device,
1129                gpu_dtype,
1130            )?;
1131            let encoded = vae.encode(&source_tensor)?;
1132            self.base
1133                .progress
1134                .stage_done("Encoding source image (VAE)", encode_start.elapsed());
1135
1136            let prepared = crate::img2img::prepare_flow_match_img2img(
1137                &encoded,
1138                seed,
1139                &[1, 32, latent_h, latent_w],
1140                start_t,
1141                req.mask_image.as_deref(),
1142                latent_h,
1143                latent_w,
1144                &device,
1145                gpu_dtype,
1146            )?;
1147            drop(vae);
1148            drop(encoded);
1149            drop(source_tensor);
1150            device.synchronize()?;
1151            self.base.progress.info("Freed VAE after source encoding");
1152            (prepared.initial_latents, prepared.inpaint_ctx)
1153        } else {
1154            let img = crate::engine::seeded_randn(
1155                seed,
1156                &[1, 32, latent_h, latent_w],
1157                &device,
1158                gpu_dtype,
1159            )?;
1160            (img, None)
1161        };
1162
1163        let state = Flux2State::new(&txt_emb, &img)?;
1164        let inpaint_ctx = inpaint_ctx
1165            .as_ref()
1166            .map(crate::img2img::pack_flux_inpaint_context)
1167            .transpose()?;
1168
1169        // --- Phase 2: Load transformer, denoise ---
1170        let xformer_size = std::fs::metadata(&self.base.paths.transformer)
1171            .map(|m| m.len())
1172            .unwrap_or(0);
1173        let xformer_activation_budget = crate::device::activation_bytes(
1174            req.width,
1175            req.height,
1176            1,
1177            crate::device::dtype_bytes(gpu_dtype),
1178            crate::device::ActivationFamily::Flux2Dit,
1179        );
1180        preflight_memory_check(
1181            "Flux.2 transformer",
1182            xformer_size,
1183            xformer_activation_budget,
1184        )?;
1185        if let Some(status) = memory_status_string() {
1186            self.base.progress.info(&status);
1187        }
1188
1189        let flux2_cfg = self.resolve_config();
1190        let xformer_stage = Instant::now();
1191        let (transformer, xformer_label) = self.load_transformer(&flux2_cfg, gpu_dtype, &device)?;
1192        self.base
1193            .progress
1194            .stage_done(xformer_label, xformer_stage.elapsed());
1195
1196        let denoise_label = format!("Denoising ({} steps)", timesteps.len().saturating_sub(1));
1197        self.base.progress.stage_start(&denoise_label);
1198        let denoise_start = Instant::now();
1199
1200        let img = transformer.denoise(
1201            &state.img,
1202            &state.img_ids,
1203            &state.txt,
1204            &state.txt_ids,
1205            &state.vec,
1206            &timesteps,
1207            req.guidance,
1208            &self.base.progress,
1209            inpaint_ctx.as_ref(),
1210        )?;
1211
1212        let img = sampling::unpack(&img, height, width)?;
1213
1214        self.base
1215            .progress
1216            .stage_done(&denoise_label, denoise_start.elapsed());
1217
1218        // Drop transformer + state to free memory for VAE decode
1219        drop(inpaint_ctx);
1220        drop(transformer);
1221        self.base.progress.info("Freed Flux.2 transformer");
1222        drop(state);
1223        drop(txt_emb);
1224        device.synchronize()?;
1225        tracing::info!("Transformer dropped (sequential mode), decoding VAE...");
1226
1227        let (vae, vae_dtype) = self.load_sequential_vae(&device, gpu_dtype)?;
1228
1229        // --- Phase 3: VAE decode ---
1230        self.base.progress.stage_start("VAE decode");
1231        let vae_decode_start = Instant::now();
1232        // DEBUG: dump pre-VAE latent (B, 32, H, W) when MOLD_FLUX2_DUMP_LATENT is set.
1233        if let Ok(dump_path) = std::env::var("MOLD_FLUX2_DUMP_LATENT") {
1234            let latent_f32 = img
1235                .to_dtype(DType::F32)?
1236                .to_device(&candle_core::Device::Cpu)?;
1237            let dims = latent_f32.dims().to_vec();
1238            let v: Vec<f32> = latent_f32.flatten_all()?.to_vec1()?;
1239            let mut bytes = Vec::with_capacity(8 * 4 + v.len() * 4);
1240            bytes.extend_from_slice(&(dims.len() as u32).to_le_bytes());
1241            for d in &dims {
1242                bytes.extend_from_slice(&(*d as u32).to_le_bytes());
1243            }
1244            for x in &v {
1245                bytes.extend_from_slice(&x.to_le_bytes());
1246            }
1247            std::fs::write(&dump_path, &bytes)?;
1248            tracing::info!(path = %dump_path, dims = ?dims, "dumped pre-VAE latent");
1249        }
1250        let img_for_vae = img.to_dtype(vae_dtype)?;
1251        let device_for_sync = device.clone();
1252        let img = crate::vae_tiling::decode_with_oom_fallback(
1253            &img_for_vae,
1254            |latents| vae.decode(latents).map_err(Into::into),
1255            || {
1256                if let Err(e) = device_for_sync.synchronize() {
1257                    tracing::warn!(
1258                        "FLUX2 (sequential) device.synchronize() after VAE OOM failed: {e}"
1259                    );
1260                }
1261            },
1262        )?;
1263
1264        let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(DType::U8)?;
1265        let img = img.i(0)?;
1266
1267        self.base
1268            .progress
1269            .stage_done("VAE decode", vae_decode_start.elapsed());
1270
1271        let output_metadata = build_output_metadata(req, seed, None);
1272        let image_bytes = encode_image(
1273            &img,
1274            req.resolved_output_format(),
1275            req.width,
1276            req.height,
1277            output_metadata.as_ref(),
1278        )?;
1279
1280        let generation_time_ms = start.elapsed().as_millis() as u64;
1281        tracing::info!(generation_time_ms, seed, "sequential generation complete");
1282
1283        Ok(GenerateResponse {
1284            images: vec![ImageData {
1285                data: image_bytes,
1286                format: req.resolved_output_format(),
1287                width: req.width,
1288                height: req.height,
1289                index: 0,
1290            }],
1291            generation_time_ms,
1292            model: req.model.clone(),
1293            seed_used: seed,
1294            video: None,
1295            gpu: None,
1296        })
1297    }
1298}
1299
1300// ---------------------------------------------------------------------------
1301// InferenceEngine implementation
1302// ---------------------------------------------------------------------------
1303
1304impl Flux2Engine {
1305    fn generate_inner(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1306        if req.scheduler.is_some() {
1307            tracing::warn!(
1308                "scheduler selection not supported for Flux.2 (flow-matching), ignoring"
1309            );
1310        }
1311        if req.guidance != 0.0 {
1312            tracing::debug!(
1313                guidance = req.guidance,
1314                "Flux.2 Klein is distilled — guidance value is ignored (no guidance embedding)"
1315            );
1316        }
1317        // Sequential mode: load-use-drop each component
1318        if self.uses_sequential_generate_path() {
1319            return self.generate_sequential(req);
1320        }
1321
1322        // Eager mode: use pre-loaded components. After a previous request we
1323        // intentionally drop the transformer before VAE decode, but the VAE
1324        // and Qwen3 shell remain resident. In that warm state, reloading the
1325        // transformer before prompt encoding recreates the highest peak
1326        // (transformer + Qwen3) and can OOM on 24 GB cards when queued
1327        // requests arrive back-to-back. Encode/drop Qwen3 first, then reload
1328        // the transformer for denoising.
1329        let delay_transformer_reload = self.base.loaded.as_ref().is_some_and(|loaded| {
1330            Self::should_delay_transformer_reload_for_prompt_encode(
1331                self.base.load_strategy,
1332                loaded.transformer.is_some(),
1333            )
1334        });
1335        if delay_transformer_reload {
1336            tracing::info!(
1337                "delaying Flux.2 transformer reload until after prompt encode to reduce peak VRAM"
1338            );
1339        }
1340
1341        let start = Instant::now();
1342        let seed = req.seed.unwrap_or_else(rand_seed);
1343
1344        let width = req.width as usize;
1345        let height = req.height as usize;
1346
1347        tracing::info!(
1348            prompt = %req.prompt,
1349            seed, width, height,
1350            steps = req.steps,
1351            "starting Flux.2 generation"
1352        );
1353
1354        // 1. Encode prompt with Qwen3 (check cache first to avoid unnecessary reload)
1355        let txt_emb = {
1356            let progress = &self.base.progress;
1357            let loaded = self
1358                .base
1359                .loaded
1360                .as_mut()
1361                .ok_or_else(|| anyhow::anyhow!("model not loaded — call load() first"))?;
1362            let cache_key = prompt_text_key(&req.prompt);
1363            if let Some(tensor) =
1364                restore_cached_tensor(&self.prompt_cache, &cache_key, &loaded.device, loaded.dtype)?
1365            {
1366                progress.cache_hit("prompt conditioning");
1367                tensor
1368            } else {
1369                // Cache miss — restore encoder if it was dropped or parked after
1370                // a previous generation.
1371                if loaded.text_encoder.model.is_none() {
1372                    let label = if loaded.text_encoder.is_parked() {
1373                        "Unparking Qwen3 encoder (CPU→GPU)"
1374                    } else {
1375                        "Reloading Qwen3 encoder"
1376                    };
1377                    progress.stage_start(label);
1378                    let reload_start = Instant::now();
1379                    if loaded.text_encoder.is_parked() {
1380                        loaded.text_encoder.unpark_to_gpu(progress)?;
1381                    } else {
1382                        loaded.text_encoder.reload(progress)?;
1383                    }
1384                    progress.stage_done(label, reload_start.elapsed());
1385                }
1386
1387                let txt_emb = Self::encode_prompt_cached(
1388                    progress,
1389                    &self.prompt_cache,
1390                    &mut loaded.text_encoder,
1391                    &req.prompt,
1392                    &loaded.device,
1393                    loaded.dtype,
1394                )?;
1395                tracing::info!("Qwen3 encoding complete");
1396
1397                // Free GPU VRAM for denoising. With `MOLD_KEEP_TE_RAM=1` and the
1398                // BF16 encoder, parameters move to host RAM instead of being
1399                // released — saves ~10 s on the next request. GGUF and Metal
1400                // flow through the original drop path.
1401                if loaded.text_encoder.on_gpu || loaded.device.is_metal() {
1402                    let park_mode = crate::device::keep_te_in_ram()
1403                        && !loaded.device.is_metal()
1404                        && !loaded.text_encoder.is_quantized;
1405                    if park_mode {
1406                        loaded.text_encoder.park_to_cpu()?;
1407                        tracing::info!(
1408                            on_gpu = loaded.text_encoder.on_gpu,
1409                            "Qwen3 encoder parked to CPU host RAM"
1410                        );
1411                    } else {
1412                        loaded.text_encoder.drop_weights();
1413                        tracing::info!(
1414                            on_gpu = loaded.text_encoder.on_gpu,
1415                            "Qwen3 encoder dropped to free memory for denoising"
1416                        );
1417                    }
1418                }
1419
1420                txt_emb
1421            }
1422        };
1423
1424        self.reload_transformer_if_needed()?;
1425
1426        let loaded = self
1427            .base
1428            .loaded
1429            .as_mut()
1430            .ok_or_else(|| anyhow::anyhow!("model not loaded — call load() first"))?;
1431        let progress = &self.base.progress;
1432
1433        // 2. Prepare latent space dimensions and timestep schedule
1434        let latent_h = height.div_ceil(8);
1435        let latent_w = width.div_ceil(8);
1436
1437        // Pre-compute timestep schedule (needed before mixing for img2img)
1438        let image_seq_len = (height / 16) * (width / 16);
1439        let mut timesteps = sampling::get_schedule(req.steps as usize, image_seq_len);
1440
1441        if req.source_image.is_some() {
1442            let (trimmed, start_index) =
1443                crate::img2img::trim_schedule_tail(&timesteps, req.steps as usize, req.strength);
1444            timesteps = trimmed;
1445            tracing::info!(
1446                strength = req.strength,
1447                start_index,
1448                start_timestep = timesteps[0],
1449                schedule = ?timesteps,
1450                remaining_steps = timesteps.len().saturating_sub(1),
1451                "img2img: truncated schedule from strength"
1452            );
1453        }
1454
1455        // 3. Generate noise / encode source image for img2img
1456        let (img, inpaint_ctx) = if let Some(ref source_bytes) = req.source_image {
1457            let start_t = timesteps[0];
1458
1459            progress.stage_start("Encoding source image (VAE)");
1460            let encode_start = Instant::now();
1461            let source_tensor = crate::img_utils::decode_source_image(
1462                source_bytes,
1463                req.width,
1464                req.height,
1465                Self::img2img_source_normalize_range(),
1466                &loaded.device,
1467                loaded.vae_dtype,
1468            )?;
1469            let encoded = loaded.vae.encode(&source_tensor)?;
1470            progress.stage_done("Encoding source image (VAE)", encode_start.elapsed());
1471
1472            let prepared = crate::img2img::prepare_flow_match_img2img(
1473                &encoded,
1474                seed,
1475                &[1, 32, latent_h, latent_w],
1476                start_t,
1477                req.mask_image.as_deref(),
1478                latent_h,
1479                latent_w,
1480                &loaded.device,
1481                loaded.dtype,
1482            )?;
1483            (prepared.initial_latents, prepared.inpaint_ctx)
1484        } else {
1485            let img = crate::engine::seeded_randn(
1486                seed,
1487                &[1, 32, latent_h, latent_w],
1488                &loaded.device,
1489                loaded.dtype,
1490            )?;
1491            (img, None)
1492        };
1493
1494        // 4. Build sampling state
1495        let state = Flux2State::new(&txt_emb, &img)?;
1496        let inpaint_ctx = inpaint_ctx
1497            .as_ref()
1498            .map(crate::img2img::pack_flux_inpaint_context)
1499            .transpose()?;
1500
1501        let denoise_label = format!("Denoising ({} steps)", timesteps.len().saturating_sub(1));
1502        progress.stage_start(&denoise_label);
1503        let denoise_start = Instant::now();
1504        tracing::info!(
1505            steps = timesteps.len().saturating_sub(1),
1506            "running denoising loop..."
1507        );
1508
1509        // 5. Denoise
1510        let transformer = loaded
1511            .transformer
1512            .as_ref()
1513            .ok_or_else(|| anyhow::anyhow!("transformer not loaded"))?;
1514        let img = transformer.denoise(
1515            &state.img,
1516            &state.img_ids,
1517            &state.txt,
1518            &state.txt_ids,
1519            &state.vec,
1520            &timesteps,
1521            req.guidance,
1522            progress,
1523            inpaint_ctx.as_ref(),
1524        )?;
1525
1526        // 6. Unpack latent to spatial
1527        let img = sampling::unpack(&img, height, width)?;
1528        progress.stage_done(&denoise_label, denoise_start.elapsed());
1529        tracing::info!("denoising complete, decoding VAE...");
1530
1531        // Free denoising intermediates and transformer before VAE decode.
1532        // The transformer consumes most of VRAM — VAE decode needs that
1533        // memory for conv2d intermediates. Transformer is reloaded next generate.
1534        drop(inpaint_ctx);
1535        drop(state);
1536        drop(txt_emb);
1537        loaded.transformer = None;
1538        // Force CUDA to complete pending operations and release freed memory.
1539        // Without this, cuMemFree is asynchronous and the freed VRAM may not
1540        // be available when VAE decode allocates its conv2d intermediates.
1541        loaded.device.synchronize()?;
1542        tracing::info!("Transformer dropped to free VRAM for VAE decode");
1543
1544        // 7. Decode with VAE
1545        progress.stage_start("VAE decode");
1546        let vae_decode_start = Instant::now();
1547        // DEBUG: dump pre-VAE latent when MOLD_FLUX2_DUMP_LATENT is set.
1548        if let Ok(dump_path) = std::env::var("MOLD_FLUX2_DUMP_LATENT") {
1549            let latent_f32 = img
1550                .to_dtype(DType::F32)?
1551                .to_device(&candle_core::Device::Cpu)?;
1552            let dims = latent_f32.dims().to_vec();
1553            let v: Vec<f32> = latent_f32.flatten_all()?.to_vec1()?;
1554            let mut bytes = Vec::with_capacity(8 * 4 + v.len() * 4);
1555            bytes.extend_from_slice(&(dims.len() as u32).to_le_bytes());
1556            for d in &dims {
1557                bytes.extend_from_slice(&(*d as u32).to_le_bytes());
1558            }
1559            for x in &v {
1560                bytes.extend_from_slice(&x.to_le_bytes());
1561            }
1562            std::fs::write(&dump_path, &bytes)?;
1563            tracing::info!(path = %dump_path, dims = ?dims, "dumped pre-VAE latent (parallel)");
1564        }
1565        let img_for_vae = img.to_dtype(loaded.vae_dtype)?;
1566        let vae = &loaded.vae;
1567        let device_for_sync = loaded.device.clone();
1568        let img = crate::vae_tiling::decode_with_oom_fallback(
1569            &img_for_vae,
1570            |latents| vae.decode(latents).map_err(Into::into),
1571            || {
1572                if let Err(e) = device_for_sync.synchronize() {
1573                    tracing::warn!(
1574                        "FLUX2 (parallel) device.synchronize() after VAE OOM failed: {e}"
1575                    );
1576                }
1577            },
1578        )?;
1579
1580        // 8. Convert to u8 image: clamp to [-1, 1], map to [0, 255]
1581        let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(DType::U8)?;
1582        let img = img.i(0)?; // remove batch dim: [3, H, W]
1583
1584        progress.stage_done("VAE decode", vae_decode_start.elapsed());
1585        tracing::info!("VAE decode complete, encoding output image...");
1586
1587        // 9. Convert candle tensor to image bytes
1588        let output_metadata = build_output_metadata(req, seed, None);
1589        let image_bytes = encode_image(
1590            &img,
1591            req.resolved_output_format(),
1592            req.width,
1593            req.height,
1594            output_metadata.as_ref(),
1595        )?;
1596
1597        let generation_time_ms = start.elapsed().as_millis() as u64;
1598        tracing::info!(generation_time_ms, seed, "generation complete");
1599
1600        Ok(GenerateResponse {
1601            images: vec![ImageData {
1602                data: image_bytes,
1603                format: req.resolved_output_format(),
1604                width: req.width,
1605                height: req.height,
1606                index: 0,
1607            }],
1608            generation_time_ms,
1609            model: req.model.clone(),
1610            seed_used: seed,
1611            video: None,
1612            gpu: None,
1613        })
1614    }
1615}
1616
1617impl InferenceEngine for Flux2Engine {
1618    fn generate(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1619        self.pending_placement = req.placement.clone();
1620        self.pending_loras = effective_flux2_loras(req);
1621        let result = self.generate_inner(req);
1622        self.pending_placement = None;
1623        self.pending_loras.clear();
1624        result
1625    }
1626
1627    fn model_name(&self) -> &str {
1628        self.base.model_name()
1629    }
1630
1631    fn is_loaded(&self) -> bool {
1632        self.base.is_loaded()
1633    }
1634
1635    fn load(&mut self) -> Result<()> {
1636        Flux2Engine::load(self)
1637    }
1638
1639    fn unload(&mut self) {
1640        self.base.unload();
1641        clear_cache(&self.prompt_cache);
1642    }
1643
1644    fn set_on_progress(&mut self, callback: ProgressCallback) {
1645        self.base.set_on_progress(callback);
1646    }
1647
1648    fn clear_on_progress(&mut self) {
1649        self.base.clear_on_progress();
1650    }
1651
1652    fn model_paths(&self) -> Option<&mold_core::ModelPaths> {
1653        Some(&self.base.paths)
1654    }
1655}
1656
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encoders::variant_resolution::Qwen3Size;
    use crate::engine::LoadStrategy;
    use crate::shared_pool::SharedPool;
    use mold_core::ModelPaths;
    use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
    use std::collections::HashMap;
    use std::fs;
    use std::path::{Path, PathBuf};
    use std::sync::{Arc, Mutex};
    use std::time::{SystemTime, UNIX_EPOCH};
    use tokenizers::models::bpe::BPE;

    /// Scratch directory that is deleted when dropped.
    ///
    /// Cleanup is tied to `Drop`, so the directory is removed even when an
    /// assertion fails and the test unwinds. Otherwise it would leak into the
    /// system temp dir. It derefs to [`Path`], so `&dir` can be passed wherever
    /// a `&Path` is expected.
    struct TempTestDir(PathBuf);

    impl std::ops::Deref for TempTestDir {
        type Target = Path;

        fn deref(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempTestDir {
        fn drop(&mut self) {
            // Best effort: a cleanup failure must not mask the test outcome
            // (and must never panic while unwinding).
            fs::remove_dir_all(&self.0).ok();
        }
    }

    /// Creates a fresh directory under the system temp dir. The name is made
    /// unique with the process id and the current time in nanoseconds.
    fn temp_test_dir(prefix: &str) -> TempTestDir {
        let suffix = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let dir = std::env::temp_dir().join(format!("{prefix}-{}-{suffix}", std::process::id()));
        fs::create_dir_all(&dir).unwrap();
        TempTestDir(dir)
    }

    /// Writes a small placeholder file named `name` inside `dir` and returns its path.
    fn touch(dir: &Path, name: &str) -> PathBuf {
        let path = dir.join(name);
        fs::write(&path, b"test").unwrap();
        path
    }

    /// Builds `ModelPaths` rooted at `dir`.
    ///
    /// The transformer is `dir/<transformer_name>`, the VAE is
    /// `dir/vae.safetensors` and the tokenizer is `dir/tokenizer.json`. The
    /// given text-encoder files and optional T5 fallback are used as-is, and
    /// every other component is left unset. The files are not created here.
    fn flux2_model_paths(
        dir: &Path,
        transformer_name: &str,
        text_encoder_files: Vec<PathBuf>,
        t5_encoder: Option<PathBuf>,
    ) -> ModelPaths {
        ModelPaths {
            transformer: dir.join(transformer_name),
            transformer_shards: vec![],
            vae: dir.join("vae.safetensors"),
            spatial_upscaler: None,
            temporal_upscaler: None,
            distilled_lora: None,
            t5_encoder,
            clip_encoder: None,
            t5_tokenizer: None,
            clip_tokenizer: None,
            clip_encoder_2: None,
            clip_tokenizer_2: None,
            text_encoder_files,
            text_tokenizer: Some(dir.join("tokenizer.json")),
            decoder: None,
        }
    }

    #[test]
    fn flux2_img2img_uses_minus_one_to_one_source_normalization() {
        assert_eq!(
            Flux2Engine::img2img_source_normalize_range(),
            crate::img_utils::NormalizeRange::MinusOneToOne
        );
    }

    #[test]
    fn sequential_img2img_encodes_source_before_transformer_load() {
        assert!(
            Flux2Engine::sequential_img2img_preencodes_source(),
            "sequential Flux.2 img2img must not keep the VAE resident while loading the transformer"
        );
    }

    #[test]
    fn eager_warm_request_delays_transformer_reload_until_after_prompt_encode() {
        assert!(
            Flux2Engine::should_delay_transformer_reload_for_prompt_encode(
                LoadStrategy::Eager,
                false
            ),
            "warm eager requests with a dropped transformer must encode/drop Qwen3 before reload"
        );
        assert!(
            !Flux2Engine::should_delay_transformer_reload_for_prompt_encode(
                LoadStrategy::Eager,
                true
            ),
            "fully loaded eager requests should keep the existing hot path"
        );
        assert!(
            !Flux2Engine::should_delay_transformer_reload_for_prompt_encode(
                LoadStrategy::Sequential,
                false
            ),
            "sequential mode already handles load-use-drop ordering"
        );
    }

    #[test]
    fn flux2_model_name_controls_transformer_and_encoder_config() {
        let base_dir = temp_test_dir("mold-flux2-config");
        let standard = Flux2Engine::new(
            "flux2-klein:q8".to_string(),
            flux2_model_paths(&base_dir, "transformer.gguf", vec![], None),
            None,
            LoadStrategy::Sequential,
            0,
            false,
            None,
        );
        let nine_b = Flux2Engine::new(
            "flux2-klein-9b:q8".to_string(),
            flux2_model_paths(&base_dir, "transformer.gguf", vec![], None),
            None,
            LoadStrategy::Sequential,
            0,
            false,
            None,
        );

        let standard_cfg = standard.resolve_config();
        let nine_b_cfg = nine_b.resolve_config();

        assert_eq!(standard_cfg.hidden_size, 3072);
        assert_eq!(standard_cfg.context_in_dim, 7680);
        assert_eq!(standard.qwen3_size(), Qwen3Size::B4);
        assert_eq!(standard.qwen3_bf16_config().hidden_size, 2560);

        assert_eq!(nine_b_cfg.hidden_size, 4096);
        assert_eq!(nine_b_cfg.context_in_dim, 12288);
        assert_eq!(nine_b.qwen3_size(), Qwen3Size::B8);
        assert_eq!(nine_b.qwen3_bf16_config().hidden_size, 4096);
    }

    #[test]
    fn flux2_text_encoder_paths_use_shards_or_t5_fallback() {
        let dir = temp_test_dir("mold-flux2-paths");
        let shard_a = touch(&dir, "encoder-1.safetensors");
        let shard_b = touch(&dir, "encoder-2.safetensors");
        let fallback = touch(&dir, "encoder.safetensors");

        let sharded = Flux2Engine::new(
            "flux2-klein:q8".to_string(),
            flux2_model_paths(
                &dir,
                "transformer.gguf",
                vec![shard_a.clone(), shard_b.clone()],
                Some(fallback.clone()),
            ),
            None,
            LoadStrategy::Sequential,
            0,
            false,
            None,
        );
        assert_eq!(sharded.text_encoder_paths(), vec![shard_a, shard_b]);

        let fallback_engine = Flux2Engine::new(
            "flux2-klein:q8".to_string(),
            flux2_model_paths(&dir, "transformer.gguf", vec![], Some(fallback.clone())),
            None,
            LoadStrategy::Sequential,
            0,
            false,
            None,
        );
        assert_eq!(fallback_engine.text_encoder_paths(), vec![fallback]);
    }

    #[test]
    fn flux2_loads_qwen3_tokenizer_through_shared_pool() {
        let dir = temp_test_dir("mold-flux2-tokenizer-pool");
        let tokenizer_path = dir.join("tokenizer.json");
        tokenizers::Tokenizer::new(BPE::default())
            .save(&tokenizer_path, false)
            .unwrap();

        let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
        let pooled = shared_pool
            .lock()
            .unwrap()
            .load_tokenizer(&tokenizer_path)
            .unwrap();

        let engine = Flux2Engine::new(
            "flux2-klein:q8".to_string(),
            flux2_model_paths(&dir, "transformer.gguf", vec![], None),
            None,
            LoadStrategy::Sequential,
            0,
            false,
            Some(shared_pool),
        );

        let loaded = engine.load_text_tokenizer(&tokenizer_path).unwrap();

        assert!(Arc::ptr_eq(&pooled, &loaded));
    }

    #[test]
    fn flux2_forced_offload_uses_sequential_generation_path() {
        let dir = temp_test_dir("mold-flux2-offload-sequential");
        let engine = Flux2Engine::new(
            "flux2-klein:bf16".to_string(),
            flux2_model_paths(&dir, "transformer.safetensors", vec![], None),
            None,
            LoadStrategy::Eager,
            0,
            true,
            None,
        );

        assert!(
            engine.uses_sequential_generate_path(),
            "Flux.2 --offload requests must reach the engine and select the \
             staged generation path instead of being silently ignored"
        );
    }

    #[test]
    fn flux2_offload_decision_gates_current_unsupported_cases() {
        assert_eq!(
            flux2_offload_decision(false, false, false),
            Flux2OffloadDecision::Disabled
        );
        assert_eq!(
            flux2_offload_decision(true, false, false),
            Flux2OffloadDecision::Selected
        );
        assert!(matches!(
            flux2_offload_decision(true, true, false),
            Flux2OffloadDecision::Unsupported(reason)
                if reason.contains("GGUF variants")
        ));
        assert!(matches!(
            flux2_offload_decision(true, false, true),
            Flux2OffloadDecision::Unsupported(reason)
                if reason.contains("LoRA")
        ));
    }

    #[test]
    fn flux2_selected_bf16_offload_reaches_runtime_loader() {
        let dir = temp_test_dir("mold-flux2-offload-loader");
        // Placeholder files at exactly the locations `flux2_model_paths` points to.
        touch(&dir, "transformer.safetensors");
        touch(&dir, "vae.safetensors");
        let encoder = touch(&dir, "encoder.safetensors");
        touch(&dir, "tokenizer.json");
        let mut engine = Flux2Engine::new(
            "flux2-klein:bf16".to_string(),
            flux2_model_paths(&dir, "transformer.safetensors", vec![encoder], None),
            None,
            LoadStrategy::Sequential,
            0,
            true,
            None,
        );
        // Pre-seed the prompt cache so the request does not need a real text encoder.
        let cfg = engine.resolve_config();
        let txt_emb = Tensor::zeros((1, 1, cfg.context_in_dim), DType::F32, &Device::Cpu).unwrap();
        engine.prompt_cache.lock().unwrap().insert(
            prompt_text_key("a cat"),
            CachedTensor::from_tensor(&txt_emb).unwrap(),
        );
        let req = GenerateRequest {
            prompt: "a cat".to_string(),
            negative_prompt: None,
            model: "flux2-klein:bf16".to_string(),
            width: 64,
            height: 64,
            steps: 1,
            guidance: 0.0,
            seed: Some(1),
            batch_size: 1,
            output_format: None,
            embed_metadata: None,
            scheduler: None,
            cfg_plus: None,
            source_image: None,
            edit_images: None,
            strength: 1.0,
            mask_image: None,
            control_image: None,
            control_model: None,
            control_scale: 1.0,
            expand: None,
            original_prompt: None,
            lora: None,
            frames: None,
            fps: None,
            upscale_model: None,
            gif_preview: false,
            enable_audio: None,
            audio_file: None,
            audio_file_path: None,
            source_video: None,
            source_video_path: None,
            keyframes: None,
            pipeline: None,
            loras: None,
            retake_range: None,
            spatial_upscale: None,
            temporal_upscale: None,
            placement: Some(mold_core::types::DevicePlacement {
                text_encoders: mold_core::types::DeviceRef::Cpu,
                advanced: Some(mold_core::types::AdvancedPlacement {
                    transformer: mold_core::types::DeviceRef::Cpu,
                    vae: mold_core::types::DeviceRef::Cpu,
                    ..Default::default()
                }),
            }),
        };

        // The placeholder weights are not loadable, so an error is expected. The
        // test only checks that the error does not come from the offload gate.
        let err = engine.generate_sequential(&req).unwrap_err().to_string();

        assert!(
            !err.contains("streaming is not implemented yet"),
            "selected BF16 offload must reach the runtime loader, got: {err}"
        );
    }

    #[test]
    fn flux2_loads_vae_tensors_through_shared_pool() {
        let dir = temp_test_dir("mold-flux2-vae-pool");
        let vae_path = dir.join("vae.safetensors");
        let weight = 1.0f32.to_le_bytes();
        let mut tensors = HashMap::new();
        tensors.insert(
            "encoder.conv_in.weight".to_string(),
            TensorView::new(SafeDtype::F32, vec![1], &weight).unwrap(),
        );
        serialize_to_file(&tensors, &None, &vae_path).unwrap();

        let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
        let pooled = shared_pool
            .lock()
            .unwrap()
            .load_safetensors_cpu_tensors(std::slice::from_ref(&vae_path))
            .unwrap()
            .unwrap();

        let engine = Flux2Engine::new(
            "flux2-klein:q8".to_string(),
            flux2_model_paths(&dir, "transformer.gguf", vec![], None),
            None,
            LoadStrategy::Sequential,
            0,
            false,
            Some(shared_pool),
        );

        let loaded = engine.load_vae_cpu_tensors().unwrap().unwrap();

        assert!(Arc::ptr_eq(&pooled, &loaded));
    }

    #[test]
    fn flux2_validate_paths_accepts_existing_files_and_returns_tokenizer() {
        let dir = temp_test_dir("mold-flux2-validate-ok");
        touch(&dir, "transformer.gguf");
        touch(&dir, "vae.safetensors");
        let encoder = touch(&dir, "encoder.safetensors");
        let tokenizer = touch(&dir, "tokenizer.json");

        let engine = Flux2Engine::new(
            "flux2-klein:q8".to_string(),
            flux2_model_paths(&dir, "transformer.gguf", vec![encoder], None),
            None,
            LoadStrategy::Sequential,
            0,
            false,
            None,
        );

        assert_eq!(engine.validate_paths().unwrap(), tokenizer);
        assert!(engine.is_gguf_transformer());
    }

    #[test]
    fn flux2_validate_paths_requires_text_encoder_paths() {
        let dir = temp_test_dir("mold-flux2-validate-missing");
        touch(&dir, "transformer.safetensors");
        touch(&dir, "vae.safetensors");
        touch(&dir, "tokenizer.json");

        let engine = Flux2Engine::new(
            "flux2-klein:bf16".to_string(),
            flux2_model_paths(&dir, "transformer.safetensors", vec![], None),
            None,
            LoadStrategy::Sequential,
            0,
            false,
            None,
        );

        let err = engine.validate_paths().unwrap_err();
        assert!(err.to_string().contains("text encoder paths required"));
        assert!(!engine.is_gguf_transformer());
    }
}