Skip to main content

mold_inference/flux/
pipeline.rs

1use anyhow::{bail, Result};
2use candle_core::{DType, Device, IndexOp, Tensor};
3use candle_nn::VarBuilder;
4use candle_transformers::models::flux;
5use candle_transformers::quantized_var_builder;
6use mold_core::{GenerateRequest, GenerateResponse, ImageData, ModelPaths};
7use std::collections::{BTreeMap, HashMap};
8use std::path::{Path, PathBuf};
9use std::sync::{Arc, Mutex};
10use std::time::Instant;
11
12use crate::cache::{
13    clear_cache, prompt_text_key, restore_cached_tensor_pair, store_cached_tensor_pair,
14    CachedTensorPair, LruCache, DEFAULT_PROMPT_CACHE_CAPACITY,
15};
16use crate::device::{
17    check_memory_budget, effective_device_ref, fmt_gb, free_vram_bytes, memory_status_string,
18    preflight_memory_check, should_offload, should_use_gpu, usable_free_vram_bytes,
19    CLIP_VRAM_THRESHOLD, MIN_OFFLOAD_VRAM,
20};
21use crate::encoders;
22use crate::engine::{rand_seed, InferenceEngine, LoadStrategy, OptionRestoreGuard};
23use crate::engine_base::EngineBase;
24use crate::image::{build_output_metadata, encode_image};
25use crate::progress::{ProgressCallback, ProgressReporter};
26
27use super::transformer::FluxTransformer;
28
29/// Some FLUX safetensors checkpoints store transformer tensors at the root
30/// while others nest them under `model.diffusion_model`.
31fn flux_transformer_var_builder<'a>(vb: VarBuilder<'a>) -> VarBuilder<'a> {
32    if vb.contains_tensor("img_in.weight") {
33        vb
34    } else if vb.contains_tensor("model.diffusion_model.img_in.weight") {
35        vb.pp("model.diffusion_model")
36    } else if vb.contains_tensor("diffusion_model.img_in.weight") {
37        vb.pp("diffusion_model")
38    } else {
39        vb
40    }
41}
42
43/// Some FLUX single-file checkpoints bundle the VAE under a wrapper prefix
44/// while the candle FLUX autoencoder expects root `encoder.*` / `decoder.*`
45/// keys.
46fn flux_vae_var_builder<'a>(vb: VarBuilder<'a>) -> VarBuilder<'a> {
47    if vb.contains_tensor("encoder.conv_in.weight") {
48        vb
49    } else if vb.contains_tensor("first_stage_model.encoder.conv_in.weight") {
50        vb.pp("first_stage_model")
51    } else if vb.contains_tensor("vae.encoder.conv_in.weight") {
52        vb.pp("vae")
53    } else {
54        vb
55    }
56}
57
58/// Check if a FLUX safetensors checkpoint stores weights in FP8 (F8_E4M3).
59/// Uses candle's DType after loading a single small tensor on CPU (img_in.weight
60/// is typically only a few KB).
61fn flux_safetensors_transformer_is_fp8(path: &std::path::Path) -> Result<bool> {
62    let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::multi(&[path])? };
63    for key in [
64        "img_in.weight",
65        "model.diffusion_model.img_in.weight",
66        "diffusion_model.img_in.weight",
67    ] {
68        if let Ok(tensor) = tensors.load(key, &Device::Cpu) {
69            return Ok(tensor.dtype() == DType::F8E4M3);
70        }
71    }
72    Ok(false)
73}
74
75fn flux_runtime_dtype(is_cuda: bool, is_quantized: bool, transformer_is_fp8: bool) -> DType {
76    if is_quantized {
77        if is_cuda {
78            DType::BF16
79        } else {
80            DType::F32
81        }
82    } else if is_cuda && transformer_is_fp8 {
83        // FP8 safetensors must go through F16 on CUDA (candle has a kernel naming
84        // bug that prevents direct CUDA FP8→BF16 casts). The lazy mmap VarBuilder
85        // handles dtype conversion during model construction.
86        DType::F16
87    } else if is_cuda {
88        DType::BF16
89    } else {
90        DType::F32
91    }
92}
93
94/// Path for the Q8 GGUF cache of an FP8 safetensors file.
95/// Cache key: stem + file size + FNV-1a hash of 4KB sampled from the weight
96/// data region (past the JSON header). This avoids collisions between
97/// different fine-tunes that share the same tensor layout and header.
98fn fp8_gguf_cache_path(path: &Path) -> PathBuf {
99    use std::io::{Read, Seek, SeekFrom};
100    let stem = path
101        .file_stem()
102        .and_then(|s| s.to_str())
103        .unwrap_or("transformer");
104    let size = std::fs::metadata(path).map(|m| m.len()).unwrap_or(0);
105    // Sample 4KB from the weight data region (past the safetensors JSON header).
106    // The header is typically ~30-60KB; sampling from 25% into the file ensures
107    // we're reading actual weight data, not the identical JSON layout.
108    let sample_offset = size / 4;
109    let content_hash = std::fs::File::open(path)
110        .and_then(|mut f| {
111            f.seek(SeekFrom::Start(sample_offset))?;
112            let mut buf = vec![0u8; 4096];
113            let n = f.read(&mut buf)?;
114            buf.truncate(n);
115            Ok(buf)
116        })
117        .map(|buf| {
118            let mut h: u64 = 0xcbf2_9ce4_8422_2325; // FNV-1a offset basis
119            for &b in &buf {
120                h ^= b as u64;
121                h = h.wrapping_mul(0x0100_0000_01b3); // FNV-1a prime
122            }
123            format!("{h:016x}")
124        })
125        .unwrap_or_else(|_| "0".to_string());
126    let cache_root = mold_core::Config::mold_dir()
127        .unwrap_or_else(|| PathBuf::from(".mold"))
128        .join("cache")
129        .join("flux-q8");
130    cache_root.join(format!("{stem}-{size}-{content_hash}.q8_0.gguf"))
131}
132
133fn q8_0_can_quantize_dims(dims: &[usize]) -> bool {
134    if dims.len() < 2 {
135        return false;
136    }
137    let block_size = candle_core::quantized::GgmlDType::Q8_0.block_size();
138    dims.last()
139        .is_some_and(|last_dim| *last_dim >= block_size && *last_dim % block_size == 0)
140}
141
/// Decide whether a tensor is left out of the Q8 GGUF cache.
///
/// A tensor is skipped when it has no dimensions at all, or when its name
/// starts with `text_encoders.`.
fn fp8_cache_should_skip_tensor(name: &str, dims: &[usize]) -> bool {
    let has_no_dims = matches!(dims, []);
    let is_text_encoder = name.strip_prefix("text_encoders.").is_some();
    has_no_dims || is_text_encoder
}
145
/// Build a temp-file path next to `cache_path` for writing a cache file.
///
/// The extension of `cache_path` is replaced with `tmp.<pid>.<seq>`, where
/// `<seq>` comes from a process-wide counter that is bumped on every call, so
/// successive calls in one process never return the same path.
fn fp8_gguf_tmp_path(cache_path: &Path) -> PathBuf {
    use std::sync::atomic::{AtomicU64, Ordering};

    static TMP_COUNTER: AtomicU64 = AtomicU64::new(0);

    let pid = std::process::id();
    let ticket = TMP_COUNTER.fetch_add(1, Ordering::Relaxed);
    cache_path.with_extension(format!("tmp.{pid}.{ticket}"))
}
151
152/// Convert an FP8 safetensors checkpoint to Q8_0 GGUF (one-time).
153///
154/// FP8 safetensors cannot run directly through candle on a 24 GB card because
155/// expanding to F16/BF16 doubles the VRAM requirement. Q8_0 GGUF keeps the
156/// model at ~12 GB and uses candle's efficient quantized matmul path.
157fn ensure_fp8_gguf_cache(path: &Path, progress: &ProgressReporter) -> Result<PathBuf> {
158    let cache_path = fp8_gguf_cache_path(path);
159    if cache_path.exists() {
160        progress.info(&format!("Using cached Q8 GGUF: {}", cache_path.display()));
161        return Ok(cache_path);
162    }
163
164    let parent = cache_path
165        .parent()
166        .ok_or_else(|| anyhow::anyhow!("invalid cache path: {}", cache_path.display()))?;
167
168    // Clean up orphaned caches from older naming schemes only.
169    // v1: {stem}.q8_0.gguf  (no size/hash — exactly "stem.q8_0.gguf")
170    // v2: {stem}-{size}.q8_0.gguf  (size only, no content hash — one dash)
171    // Current v3: {stem}-{size}-{hash}.q8_0.gguf  (two dashes — NOT cleaned)
172    // We only remove v1/v2 formats. Valid v3 caches for other checkpoints
173    // (different size/hash) are preserved to avoid expensive re-quantization.
174    let stem = path
175        .file_stem()
176        .and_then(|s| s.to_str())
177        .unwrap_or("transformer");
178    std::fs::create_dir_all(parent)?;
179    let old_v1 = parent.join(format!("{stem}.q8_0.gguf"));
180    if old_v1.exists() {
181        tracing::info!(path = %old_v1.display(), "removing v1 orphaned FP8 cache");
182        let _ = std::fs::remove_file(&old_v1);
183    }
184    // v2 format: {stem}-{digits}.q8_0.gguf (one dash, no hash)
185    if let Ok(entries) = std::fs::read_dir(parent) {
186        let v2_prefix = format!("{stem}-");
187        let suffix = ".q8_0.gguf";
188        for entry in entries.flatten() {
189            let name = entry.file_name();
190            let Some(name_str) = name.to_str() else {
191                continue;
192            };
193            if !name_str.starts_with(&v2_prefix) || !name_str.ends_with(suffix) {
194                continue;
195            }
196            // Extract the middle part between prefix and suffix
197            let middle = &name_str[v2_prefix.len()..name_str.len() - suffix.len()];
198            // v2 has no dash in the middle (just digits for size).
199            // v3 has a dash (size-hash). Only remove v2.
200            if !middle.contains('-') && middle.chars().all(|c| c.is_ascii_digit()) {
201                tracing::info!(path = %entry.path().display(), "removing v2 orphaned FP8 cache");
202                let _ = std::fs::remove_file(entry.path());
203            }
204        }
205    }
206
207    progress.info("Converting FP8 checkpoint to Q8 GGUF cache (one-time, may take a few minutes)");
208    tracing::info!(
209        source = %path.display(),
210        cache = %cache_path.display(),
211        "converting FP8 safetensors to Q8_0 GGUF cache"
212    );
213
214    let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::multi(&[path])? };
215
216    // Detect and strip the common prefix used in some checkpoints
217    let prefix = if tensors.get("img_in.weight").is_ok() {
218        ""
219    } else if tensors.get("model.diffusion_model.img_in.weight").is_ok() {
220        "model.diffusion_model."
221    } else if tensors.get("diffusion_model.img_in.weight").is_ok() {
222        "diffusion_model."
223    } else {
224        ""
225    };
226
227    // Enumerate all tensor names via MmapedSafetensors::tensors()
228    let all_names: Vec<String> = tensors
229        .tensors()
230        .into_iter()
231        .map(|(name, _)| name)
232        .collect();
233
234    let mut qtensors: Vec<(String, candle_core::quantized::QTensor)> = Vec::new();
235
236    let total = all_names.len();
237    for (i, name) in all_names.iter().enumerate() {
238        if (i + 1) % 50 == 0 || i + 1 == total {
239            progress.info(&format!("Quantizing tensor {}/{total}", i + 1));
240        }
241
242        let tensor = tensors.load(name, &Device::Cpu)?;
243        // Strip prefix for GGUF (quantized model expects unprefixed names)
244        let out_name = if !prefix.is_empty() && name.starts_with(prefix) {
245            name[prefix.len()..].to_string()
246        } else {
247            name.clone()
248        };
249
250        if fp8_cache_should_skip_tensor(&out_name, tensor.dims()) {
251            continue;
252        }
253
254        let can_quantize = q8_0_can_quantize_dims(tensor.dims());
255
256        let qt = if can_quantize {
257            candle_core::quantized::QTensor::quantize(
258                &tensor,
259                candle_core::quantized::GgmlDType::Q8_0,
260            )?
261        } else {
262            // Small/odd-shaped tensors (norms, biases): store as F32
263            candle_core::quantized::QTensor::quantize(
264                &tensor,
265                candle_core::quantized::GgmlDType::F32,
266            )?
267        };
268        qtensors.push((out_name, qt));
269    }
270
271    // Write GGUF cache (clean up temp file on error)
272    let tmp_path = fp8_gguf_tmp_path(&cache_path);
273    let write_result = (|| -> Result<()> {
274        let file = std::fs::File::create(&tmp_path)?;
275        let mut writer = std::io::BufWriter::new(file);
276        let tensor_refs: Vec<(&str, &candle_core::quantized::QTensor)> =
277            qtensors.iter().map(|(n, q)| (n.as_str(), q)).collect();
278        candle_core::quantized::gguf_file::write(&mut writer, &[], &tensor_refs)?;
279        Ok(())
280    })();
281    if let Err(e) = write_result {
282        let _ = std::fs::remove_file(&tmp_path);
283        return Err(e);
284    }
285    if cache_path.exists() {
286        let _ = std::fs::remove_file(&tmp_path);
287        progress.info(&format!("Using cached Q8 GGUF: {}", cache_path.display()));
288        return Ok(cache_path);
289    }
290    std::fs::rename(&tmp_path, &cache_path)?;
291
292    progress.info(&format!("Q8 GGUF cache created: {}", cache_path.display()));
293    tracing::info!(cache = %cache_path.display(), "FP8→Q8_0 GGUF cache created");
294    Ok(cache_path)
295}
296
297// ── City96-format GGUF embedding patching ──────────────────────────────────
298
/// Embedding tensors required by all FLUX models (schnell and dev).
///
/// Listed as weight/bias pairs for `img_in`, then the `in_layer` and
/// `out_layer` of `time_in` and `vector_in`.
const FLUX_EMBEDDING_TENSORS: &[&str] = &[
    // img_in
    "img_in.weight",
    "img_in.bias",
    // time_in
    "time_in.in_layer.weight",
    "time_in.in_layer.bias",
    "time_in.out_layer.weight",
    "time_in.out_layer.bias",
    // vector_in
    "vector_in.in_layer.weight",
    "vector_in.in_layer.bias",
    "vector_in.out_layer.weight",
    "vector_in.out_layer.bias",
];
312
/// Additional embedding tensors for FLUX-dev (guidance-based) models.
///
/// Weight/bias pairs for the `in_layer` and `out_layer` of `guidance_in`.
const FLUX_GUIDANCE_EMBEDDING_TENSORS: &[&str] = &[
    // guidance_in.in_layer
    "guidance_in.in_layer.weight",
    "guidance_in.in_layer.bias",
    // guidance_in.out_layer
    "guidance_in.out_layer.weight",
    "guidance_in.out_layer.bias",
];
320
321/// Lightweight check: does a GGUF file contain the FLUX embedding layers?
322/// Reads only the GGUF header (tensor_infos), not the tensor data.
323///
324/// Relies on the city96-format property that embedding tensors are either
325/// all present or all absent. A GGUF with `img_in.weight` but missing other
326/// embeddings would pass this check.
327fn gguf_has_embeddings(path: &Path) -> Result<bool> {
328    let mut file = std::fs::File::open(path)?;
329    let content = candle_core::quantized::gguf_file::Content::read(&mut file)?;
330    Ok(content.tensor_infos.contains_key("img_in.weight"))
331}
332
333/// Does a GGUF contain the flux-dev-only `guidance_in` tensors? Schnell GGUFs
334/// return false because the schnell architecture is distilled without guidance.
335fn gguf_has_guidance(path: &Path) -> Result<bool> {
336    let mut file = std::fs::File::open(path)?;
337    let content = candle_core::quantized::gguf_file::Content::read(&mut file)?;
338    Ok(content
339        .tensor_infos
340        .contains_key("guidance_in.in_layer.weight"))
341}
342
343/// Search for a downloaded FLUX GGUF that contains complete embeddings.
344///
345/// Prefers larger quantizations (more likely downloaded) first. When
346/// `needs_guidance` is true, schnell candidates are skipped and dev candidates
347/// are verified to contain `guidance_in` tensors — a schnell GGUF passes the
348/// basic `img_in` check but cannot supply `guidance_in` for a dev-family target.
349///
350/// When `models_dir_override` is `Some`, searches that directory instead of
351/// the config-resolved models dir (used by tests to avoid global state).
352fn find_flux_reference_gguf(
353    needs_guidance: bool,
354    models_dir_override: Option<&Path>,
355) -> Option<PathBuf> {
356    let config = mold_core::Config::load_or_default();
357    let models_dir = models_dir_override
358        .map(PathBuf::from)
359        .unwrap_or_else(|| config.resolved_models_dir());
360
361    // Dev candidates satisfy both schnell and dev targets (schnell tensors are a
362    // subset of dev). Schnell candidates only satisfy schnell targets.
363    // flux-krea is a dev-family fine-tune shipped as complete GGUFs by
364    // QuantStack, so it carries the full embedding set including guidance_in —
365    // fall back to it before asking the user to download flux-dev.
366    let mut candidates: Vec<&str> = vec![
367        "flux-dev:q8",
368        "flux-dev:q6",
369        "flux-dev:q4",
370        "flux-krea:q8",
371        "flux-krea:q6",
372        "flux-krea:q4",
373    ];
374    if !needs_guidance {
375        candidates.extend(["flux-schnell:q8", "flux-schnell:q4"]);
376    }
377
378    for name in candidates {
379        let Some(manifest) = mold_core::manifest::find_manifest(name) else {
380            continue;
381        };
382        // Find the transformer file in the manifest
383        let Some(xformer_file) = manifest
384            .files
385            .iter()
386            .find(|f| f.component == mold_core::manifest::ModelComponent::Transformer)
387        else {
388            continue;
389        };
390        let xformer_path =
391            models_dir.join(mold_core::manifest::storage_path(manifest, xformer_file));
392        if !xformer_path.exists() {
393            continue;
394        }
395        // Verify it actually has the embeddings (don't assume)
396        match gguf_has_embeddings(&xformer_path) {
397            Ok(true) => {
398                if needs_guidance {
399                    match gguf_has_guidance(&xformer_path) {
400                        Ok(true) => {}
401                        Ok(false) => {
402                            tracing::debug!(
403                                model = name,
404                                "reference candidate lacks guidance_in, skipping for dev target"
405                            );
406                            continue;
407                        }
408                        Err(e) => {
409                            tracing::debug!(
410                                model = name,
411                                err = %e,
412                                "failed to probe guidance tensors"
413                            );
414                            continue;
415                        }
416                    }
417                }
418                tracing::info!(
419                    reference = %xformer_path.display(),
420                    model = name,
421                    needs_guidance,
422                    "found reference FLUX GGUF with embeddings"
423                );
424                return Some(xformer_path);
425            }
426            Ok(false) => {
427                tracing::debug!(
428                    model = name,
429                    "reference candidate also missing embeddings, skipping"
430                );
431            }
432            Err(e) => {
433                tracing::debug!(model = name, err = %e, "failed to probe reference candidate");
434            }
435        }
436    }
437    None
438}
439
440/// Cache path for a GGUF patched with missing embedding layers.
441/// Same FNV-1a content hashing scheme as `fp8_gguf_cache_path`.
442fn embedding_patched_cache_path(path: &Path) -> PathBuf {
443    use std::io::{Read, Seek, SeekFrom};
444    let stem = path
445        .file_stem()
446        .and_then(|s| s.to_str())
447        .unwrap_or("transformer");
448    let size = std::fs::metadata(path).map(|m| m.len()).unwrap_or(0);
449    let sample_offset = size / 4;
450    let content_hash = std::fs::File::open(path)
451        .and_then(|mut f| {
452            f.seek(SeekFrom::Start(sample_offset))?;
453            let mut buf = vec![0u8; 4096];
454            let n = f.read(&mut buf)?;
455            buf.truncate(n);
456            Ok(buf)
457        })
458        .map(|buf| {
459            let mut h: u64 = 0xcbf2_9ce4_8422_2325;
460            for &b in &buf {
461                h ^= b as u64;
462                h = h.wrapping_mul(0x0100_0000_01b3);
463            }
464            format!("{h:016x}")
465        })
466        .unwrap_or_else(|_| "0".to_string());
467    let cache_root = mold_core::Config::mold_dir()
468        .unwrap_or_else(|| PathBuf::from(".mold"))
469        .join("cache")
470        .join("flux-embeddings");
471    cache_root.join(format!("{stem}-{size}-{content_hash}.patched.gguf"))
472}
473
474/// Ensure a GGUF file has complete FLUX embedding layers.
475///
476/// City96-format GGUFs (used by community fine-tune quantizations like
477/// UltraReal) only include the diffusion blocks but omit input embedding
478/// layers (`img_in`, `time_in`, `vector_in`, `guidance_in`). This function
479/// detects incomplete GGUFs and patches them by sourcing the missing
480/// embeddings from a reference FLUX GGUF (e.g. flux-dev:q8).
481///
482/// Returns the original path if the GGUF is already complete, or the path
483/// to a patched cache file.
484///
485/// `models_dir_override` is forwarded to `find_flux_reference_gguf` and
486/// only used by tests to avoid mutating process-global environment variables.
487fn ensure_gguf_embeddings(
488    path: &Path,
489    is_schnell: bool,
490    progress: &ProgressReporter,
491    models_dir_override: Option<&Path>,
492) -> Result<PathBuf> {
493    let cache_path = embedding_patched_cache_path(path);
494    if cache_path.exists() {
495        progress.info(&format!(
496            "Using cached embedding-patched GGUF: {}",
497            cache_path.display()
498        ));
499        return Ok(cache_path);
500    }
501
502    // Probe whether embeddings are actually missing
503    if gguf_has_embeddings(path)? {
504        return Ok(path.to_path_buf());
505    }
506
507    progress.info(
508        "GGUF is missing FLUX embedding layers (city96 format) — patching from reference model",
509    );
510    tracing::info!(
511        path = %path.display(),
512        is_schnell,
513        "GGUF missing embedding layers, searching for reference model"
514    );
515
516    let source_name = path
517        .file_name()
518        .and_then(|n| n.to_str())
519        .unwrap_or("<unknown>");
520    let needs_guidance = !is_schnell;
521    let reference_path =
522        find_flux_reference_gguf(needs_guidance, models_dir_override).ok_or_else(|| {
523            let family = if needs_guidance { "dev" } else { "schnell" };
524            anyhow::anyhow!(
525                "{source_name} is a city96-format GGUF that ships only the diffusion \
526                 blocks — its FLUX input embedding layers (img_in, time_in, vector_in{guidance}) \
527                 must be sourced from a complete flux-{family} GGUF, but none is downloaded.\n\n\
528                 To fix this:\n\n  mold pull flux-dev:q8\n\n\
529                 Then retry — mold will patch the incomplete GGUF from the reference.",
530                guidance = if needs_guidance { ", guidance_in" } else { "" },
531            )
532        })?;
533
534    // Determine which embedding tensors we need
535    let mut needed: Vec<&str> = FLUX_EMBEDDING_TENSORS.to_vec();
536    if !is_schnell {
537        needed.extend_from_slice(FLUX_GUIDANCE_EMBEDDING_TENSORS);
538    }
539
540    // Read source (incomplete) GGUF
541    progress.info("Reading source GGUF tensors...");
542    let mut src_file = std::fs::File::open(path)?;
543    let src_content = candle_core::quantized::gguf_file::Content::read(&mut src_file)?;
544
545    // Read only the needed embedding tensors from the reference GGUF
546    progress.info(&format!(
547        "Extracting {} embedding tensors from reference: {}",
548        needed.len(),
549        reference_path
550            .file_name()
551            .and_then(|n| n.to_str())
552            .unwrap_or("?")
553    ));
554    let mut ref_file = std::fs::File::open(&reference_path)?;
555    let ref_content = candle_core::quantized::gguf_file::Content::read(&mut ref_file)?;
556
557    let cpu = Device::Cpu;
558
559    // Load all source tensors
560    let mut qtensors: Vec<(String, candle_core::quantized::QTensor)> = Vec::new();
561    let total = src_content.tensor_infos.len();
562    for (i, name) in src_content.tensor_infos.keys().enumerate() {
563        if (i + 1) % 100 == 0 || i + 1 == total {
564            progress.info(&format!("Loading source tensor {}/{total}", i + 1));
565        }
566        let tensor = src_content.tensor(&mut src_file, name, &cpu)?;
567        qtensors.push((name.clone(), tensor));
568    }
569
570    // Load missing embedding tensors from reference
571    let mut patched_count = 0usize;
572    for name in &needed {
573        if src_content.tensor_infos.contains_key(*name) {
574            continue; // already present in source
575        }
576        if !ref_content.tensor_infos.contains_key(*name) {
577            bail!(
578                "while patching {source_name}: the only downloaded reference ({}) \
579                 is also missing '{name}'. This model needs a complete flux-dev GGUF \
580                 — run 'mold pull flux-dev:q8' and retry.",
581                reference_path
582                    .file_name()
583                    .and_then(|n| n.to_str())
584                    .unwrap_or("<unknown>"),
585            );
586        }
587        let tensor = ref_content.tensor(&mut ref_file, name, &cpu)?;
588        tracing::debug!(tensor = name, "patching embedding tensor from reference");
589        qtensors.push((name.to_string(), tensor));
590        patched_count += 1;
591    }
592
593    progress.info(&format!(
594        "Patched {patched_count} embedding tensors from reference"
595    ));
596
597    // Write patched GGUF
598    let parent = cache_path
599        .parent()
600        .ok_or_else(|| anyhow::anyhow!("invalid cache path: {}", cache_path.display()))?;
601    std::fs::create_dir_all(parent)?;
602    let tmp_path = cache_path.with_extension(format!("tmp.{}", std::process::id()));
603    let write_result = (|| -> Result<()> {
604        let file = std::fs::File::create(&tmp_path)?;
605        let mut writer = std::io::BufWriter::new(file);
606        let tensor_refs: Vec<(&str, &candle_core::quantized::QTensor)> =
607            qtensors.iter().map(|(n, q)| (n.as_str(), q)).collect();
608        candle_core::quantized::gguf_file::write(&mut writer, &[], &tensor_refs)?;
609        Ok(())
610    })();
611    if let Err(e) = write_result {
612        let _ = std::fs::remove_file(&tmp_path);
613        return Err(e);
614    }
615    std::fs::rename(&tmp_path, &cache_path)?;
616
617    progress.info(&format!(
618        "Embedding-patched GGUF cache created: {}",
619        cache_path.display()
620    ));
621    tracing::info!(
622        cache = %cache_path.display(),
623        patched_count,
624        "embedding-patched GGUF cache created"
625    );
626    Ok(cache_path)
627}
628
629fn flux_safetensors_var_builder<'a>(
630    path: &std::path::Path,
631    dtype: DType,
632    device: &Device,
633    component: &str,
634    progress: &ProgressReporter,
635) -> Result<VarBuilder<'a>> {
636    let aliases = flux_rms_norm_scale_aliases(path)?;
637    if aliases.is_empty() {
638        crate::weight_loader::load_safetensors_with_progress(
639            std::slice::from_ref(&path),
640            dtype,
641            device,
642            component,
643            progress,
644        )
645    } else {
646        tracing::info!(
647            alias_count = aliases.len(),
648            path = %path.display(),
649            "FLUX checkpoint uses RMSNorm .weight keys; aliasing .scale lookups"
650        );
651        crate::weight_loader::load_safetensors_with_aliases(
652            std::slice::from_ref(&path),
653            dtype,
654            device,
655            component,
656            progress,
657            aliases,
658        )
659    }
660}
661
662fn flux_rms_norm_scale_aliases(path: &std::path::Path) -> Result<BTreeMap<String, String>> {
663    let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::multi(&[path])? };
664    let mut aliases = BTreeMap::new();
665    for prefix in ["", "model.diffusion_model.", "diffusion_model."] {
666        for i in 0..64 {
667            for stream in ["img_attn", "txt_attn"] {
668                for norm in ["query_norm", "key_norm"] {
669                    let target = format!("{prefix}double_blocks.{i}.{stream}.norm.{norm}.scale");
670                    let source = format!("{prefix}double_blocks.{i}.{stream}.norm.{norm}.weight");
671                    if tensors.get(&target).is_err() && tensors.get(&source).is_ok() {
672                        aliases.insert(target, source);
673                    }
674                }
675            }
676        }
677        for i in 0..128 {
678            for norm in ["query_norm", "key_norm"] {
679                let target = format!("{prefix}single_blocks.{i}.norm.{norm}.scale");
680                let source = format!("{prefix}single_blocks.{i}.norm.{norm}.weight");
681                if tensors.get(&target).is_err() && tensors.get(&source).is_ok() {
682                    aliases.insert(target, source);
683                }
684            }
685        }
686    }
687    Ok(aliases)
688}
689
690/// Build a LoRA-patching VarBuilder that wraps mmap'd base weights.
691///
692/// Uses a custom `SimpleBackend` that intercepts every `vb.get()` call during
693/// model construction.  Each tensor loads from mmap directly to GPU with LoRA
694/// deltas applied inline — identical memory profile to the non-LoRA mmap path.
695///
696/// Multi-LoRA: pass a slice with more than one weight and the deltas merge
697/// additively. Each adapter's contribution is independently cached (per
698/// path + scale) so a stack of (cinematic, dramatic-light) reuses both
699/// matmuls when the user toggles either back on later.
700fn flux_lora_var_builder<'a>(
701    transformer_path: &Path,
702    loras: &[mold_core::LoraWeight],
703    dtype: DType,
704    device: &Device,
705    progress: &ProgressReporter,
706    delta_cache: Option<std::sync::Arc<std::sync::Mutex<super::lora::LoraDeltaCache>>>,
707) -> Result<VarBuilder<'a>> {
708    use super::lora;
709
710    let adapters: Vec<std::sync::Arc<lora::LoraAdapter>> = loras
711        .iter()
712        .map(|w| {
713            progress.info("Loading LoRA adapter");
714            let adapter = lora::get_or_load_adapter(Path::new(&w.path))?;
715            progress.info(&format!(
716                "LoRA: {} layers, rank {}, scale {:.2}",
717                adapter.layers.len(),
718                adapter.rank,
719                w.scale,
720            ));
721            anyhow::Ok(adapter)
722        })
723        .collect::<Result<_>>()?;
724
725    let specs: Vec<lora::LoraSpec<'_>> = adapters
726        .iter()
727        .zip(loras.iter())
728        .map(|(adapter, w)| lora::LoraSpec {
729            adapter: adapter.as_ref(),
730            scale: w.scale,
731            path_hash: lora_path_hash(&w.path),
732        })
733        .collect();
734
735    lora::lora_var_builder(
736        transformer_path,
737        &specs,
738        dtype,
739        device,
740        progress,
741        delta_cache,
742    )
743}
744
/// Hash a LoRA file path into a `u64`. Used as the per-LoRA cache-key
/// component so the delta cache survives transformer rebuilds and
/// disambiguates adapters in a multi-LoRA stack.
///
/// The same path string always yields the same value within one build, since
/// `DefaultHasher::new()` starts from fixed keys.
fn lora_path_hash(path: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut state = DefaultHasher::new();
    Hash::hash(path, &mut state);
    state.finish()
}
754
755/// Same wrapper for the GGUF (quantized) path.
756fn flux_gguf_lora_var_builder(
757    transformer_path: &Path,
758    loras: &[mold_core::LoraWeight],
759    device: &Device,
760    progress: &ProgressReporter,
761    delta_cache: Option<std::sync::Arc<std::sync::Mutex<super::lora::LoraDeltaCache>>>,
762) -> Result<candle_transformers::quantized_var_builder::VarBuilder> {
763    use super::lora;
764
765    let adapters: Vec<std::sync::Arc<lora::LoraAdapter>> = loras
766        .iter()
767        .map(|w| {
768            progress.info("Loading LoRA adapter");
769            let adapter = lora::get_or_load_adapter(Path::new(&w.path))?;
770            progress.info(&format!(
771                "LoRA: {} layers, rank {}, scale {:.2}",
772                adapter.layers.len(),
773                adapter.rank,
774                w.scale,
775            ));
776            anyhow::Ok(adapter)
777        })
778        .collect::<Result<_>>()?;
779
780    let specs: Vec<lora::LoraSpec<'_>> = adapters
781        .iter()
782        .zip(loras.iter())
783        .map(|(adapter, w)| lora::LoraSpec {
784            adapter: adapter.as_ref(),
785            scale: w.scale,
786            path_hash: lora_path_hash(&w.path),
787        })
788        .collect();
789
790    lora::gguf_lora_var_builder(transformer_path, &specs, device, progress, delta_cache)
791}
792
/// Three-state switch for bypass-mode LoRA, read from `MOLD_LORA_BYPASS`.
///
/// * `Auto` — enable bypass on every supported path: the offload transformer
///   (avoids the ~24 GB CPU-resident BF16 merge) and the GGUF transformer
///   (avoids the minutes-long, ~95 GB peak dequant→merge→requant cycle on Q8
///   with a stack of LoRAs).
/// * `On` — force bypass.
/// * `Off` — revert to the legacy `flux_lora_var_builder` /
///   `gguf_lora_var_builder` so users can regression-check a build.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum LoraBypassMode {
    /// Variable unset, not valid Unicode, or set to an unrecognised value.
    Auto,
    /// `on`, `1` or `true`.
    On,
    /// `off`, `0` or `false`.
    Off,
}

impl LoraBypassMode {
    /// Parse `MOLD_LORA_BYPASS`, ignoring surrounding whitespace and ASCII
    /// case. Anything that is not an explicit on/off spelling yields `Auto`.
    fn from_env() -> Self {
        let Ok(raw) = std::env::var("MOLD_LORA_BYPASS") else {
            return Self::Auto;
        };

        let normalized = raw.trim().to_ascii_lowercase();
        match normalized.as_str() {
            "on" | "1" | "true" => Self::On,
            "off" | "0" | "false" => Self::Off,
            _ => Self::Auto,
        }
    }
}
822
823/// Build a [`super::lora_bypass::LoraRegistry`] for any bypass-capable
824/// path (offload or GGUF/quantized).
825///
826/// Adapters are placed on `device` at `dtype` (typically GPU + BF16) so
827/// the per-step path never round-trips them CPU↔GPU. Both paths use the
828/// same registry shape: keys are FLUX candle tensor names, values are
829/// the bypass adapters that fire each time that Linear runs forward.
830///
831/// Returns `Ok(None)` when `loras` is empty so callers keep their
832/// no-LoRA hot path.
833fn build_lora_registry(
834    loras: &[mold_core::LoraWeight],
835    cfg: &flux::model::Config,
836    device: &Device,
837    dtype: DType,
838    progress: &ProgressReporter,
839) -> Result<Option<super::lora_bypass::LoraRegistry>> {
840    use super::lora;
841    use super::lora_bypass;
842
843    if loras.is_empty() {
844        return Ok(None);
845    }
846
847    let adapters: Vec<lora::LoraAdapter> = loras
848        .iter()
849        .map(|w| {
850            progress.info("Loading LoRA adapter (bypass)");
851            let adapter = lora::LoraAdapter::load(Path::new(&w.path))?;
852            progress.info(&format!(
853                "LoRA: {} layers, rank {}, scale {:.2}",
854                adapter.layers.len(),
855                adapter.rank,
856                w.scale,
857            ));
858            anyhow::Ok(adapter)
859        })
860        .collect::<Result<_>>()?;
861
862    let specs: Vec<lora::LoraSpec<'_>> = adapters
863        .iter()
864        .zip(loras.iter())
865        .map(|(adapter, w)| lora::LoraSpec {
866            adapter,
867            scale: w.scale,
868            path_hash: lora_path_hash(&w.path),
869        })
870        .collect();
871
872    // Pre-compute the fused linear out-row counts that bypass-mode
873    // needs to translate component-index targets (e.g. "Q only") into
874    // absolute slice offsets.
875    let h = cfg.hidden_size;
876    let mlp_sz = (h as f64 * cfg.mlp_ratio) as usize;
877    let mut linear_out_dims: std::collections::HashMap<String, usize> =
878        std::collections::HashMap::new();
879    for idx in 0..cfg.depth {
880        // Double blocks: img_attn.qkv / txt_attn.qkv each 3*h.
881        linear_out_dims.insert(format!("double_blocks.{idx}.img_attn.qkv.weight"), 3 * h);
882        linear_out_dims.insert(format!("double_blocks.{idx}.txt_attn.qkv.weight"), 3 * h);
883    }
884    for idx in 0..cfg.depth_single_blocks {
885        // Single block linear1 fuses [Q, K, V, MLP] = 3*h + mlp_sz.
886        linear_out_dims.insert(
887            format!("single_blocks.{idx}.linear1.weight"),
888            3 * h + mlp_sz,
889        );
890    }
891
892    let registry = lora_bypass::build_registry(&specs, &linear_out_dims, device, dtype)?;
893    progress.info(&format!(
894        "LoRA bypass: {} target tensors, adapters resident on {device:?}",
895        registry.len()
896    ));
897    Ok(Some(registry))
898}
899
900/// Resolve the effective LoRA list for a request.
901///
902/// Wire format intentionally accepts both `lora` (single) and `loras`
903/// (plural) for back-compat with older clients. When both are set,
904/// `loras` wins — single-form callers haven't been updated yet but
905/// new clients always populate the plural shape.
906///
907/// Entries whose `scale.abs() < ZERO_SCALE_EPS` are dropped: a slider
908/// pinned to zero is a no-op patch and forcing the transformer to
909/// rebuild for it is pure overhead. A `tracing::debug!` records each
910/// drop so a user wondering "why didn't my LoRA apply" can spot it
911/// in `RUST_LOG=debug` output.
912pub(crate) fn effective_loras(req: &mold_core::GenerateRequest) -> Vec<mold_core::LoraWeight> {
913    /// Threshold below which a LoRA scale is treated as off. Matches
914    /// the precision of an f64 scrubbed by a UI slider — anything
915    /// closer to zero than this is the user nudging the slider, not
916    /// a deliberate negative weight.
917    const ZERO_SCALE_EPS: f64 = 1e-8;
918
919    let raw: Vec<mold_core::LoraWeight> = if let Some(plural) = &req.loras {
920        if !plural.is_empty() {
921            plural.clone()
922        } else {
923            req.lora.iter().cloned().collect()
924        }
925    } else {
926        req.lora.iter().cloned().collect()
927    };
928
929    raw.into_iter()
930        .filter(|w| {
931            let keep = w.scale.abs() > ZERO_SCALE_EPS;
932            if !keep {
933                tracing::debug!(
934                    path = w.path.as_str(),
935                    scale = w.scale,
936                    "dropping zero-scale LoRA from effective stack"
937                );
938            }
939            keep
940        })
941        .collect()
942}
943
/// Loaded FLUX model components, ready for inference.
///
/// By default the FLUX transformer and VAE run on the GPU, while T5 and CLIP
/// run on GPU or CPU depending on available VRAM (checked at load time after
/// the transformer is loaded); per-request placement overrides can change
/// either choice. When T5/CLIP are loaded on GPU, they are dropped after
/// encoding to free VRAM for the denoising pass (their weights are only
/// needed for prompt encoding).
struct LoadedFlux {
    /// None after being dropped for VAE decode VRAM; reloaded on next generate.
    flux_model: Option<FluxTransformer>,
    /// T5 text encoder (device chosen at load time).
    t5: encoders::t5::T5Encoder,
    /// CLIP text encoder (device chosen at load time).
    clip: encoders::clip::ClipEncoder,
    /// FLUX autoencoder used for latent decode.
    vae: flux::autoencoder::AutoEncoder,
    /// GPU device for FLUX transformer + VAE
    device: Device,
    /// Runtime dtype selected for the transformer device; also used for
    /// encoders that end up on the GPU.
    dtype: DType,
    /// Effective VAE dtype after `MOLD_VAE_DTYPE` resolution. Stored so the
    /// post-denoise cast and the decode forward pass agree on precision —
    /// the eager path loads the VAE once at startup so this is captured at
    /// load time and persists for the engine's lifetime. Sequential reloads
    /// re-resolve per request.
    vae_dtype: DType,
    /// True for the schnell variant; selects the schnell transformer/VAE
    /// configs instead of the dev ones.
    is_schnell: bool,
    /// True if using quantized GGUF model (state tensors must be F32)
    is_quantized: bool,
    /// Resolved transformer path (may be a GGUF cache for FP8 models).
    transformer_path: PathBuf,
    /// The actual T5 encoder path used (may be a quantized GGUF, not the original FP16 path).
    t5_encoder_path: std::path::PathBuf,
}
972
973/// Fingerprint of a single LoRA adapter (path + scale). Used to detect
974/// when the active LoRA stack has changed so we know to rebuild the
975/// transformer; an unchanged stack reuses the previously merged weights.
976#[derive(Clone, PartialEq, Eq)]
977struct LoraFingerprint {
978    path_hash: u64,
979    scale_bits: u64,
980}
981
982impl LoraFingerprint {
983    fn from_lora_weight(lora: &mold_core::LoraWeight) -> Self {
984        Self {
985            path_hash: lora_path_hash(&lora.path),
986            scale_bits: lora.scale.to_bits(),
987        }
988    }
989}
990
991/// Fingerprint of an ordered LoRA stack. Equality is order-sensitive —
992/// `[A, B]` and `[B, A]` produce identical numerical results in theory
993/// (delta sums commute) but the wrapper still considers them distinct
994/// because the user-facing intent is order-driven (e.g. style vs
995/// override) and the cost of one redundant rebuild is small.
996fn fingerprint_stack(loras: &[mold_core::LoraWeight]) -> Vec<LoraFingerprint> {
997    loras
998        .iter()
999        .map(LoraFingerprint::from_lora_weight)
1000        .collect()
1001}
1002
/// FLUX inference engine backed by candle.
///
/// Holds the shared engine state (`base`), user preferences that steer model
/// selection, and several caches that avoid repeating expensive work across
/// generations.
pub struct FluxEngine {
    /// Shared engine state: model name, paths, load strategy, GPU ordinal,
    /// progress reporter and the loaded components (if any).
    base: EngineBase<LoadedFlux>,
    /// Optional explicit override for is_schnell; if None, auto-detect from transformer filename.
    is_schnell_override: Option<bool>,
    /// T5 variant preference: None/"auto" = auto-select, "fp16" = force FP16, "q8"/"q5"/etc = specific quantized.
    t5_variant: Option<String>,
    /// LRU cache of prompt conditioning tensor pairs, keyed by prompt text.
    prompt_cache: Mutex<LruCache<String, CachedTensorPair>>,
    /// Cached result of FP8 safetensors probe (None = not yet checked).
    transformer_is_fp8: Option<bool>,
    /// Cached resolved transformer path (GGUF cache for FP8, or original path).
    /// Avoids re-computing the cache key (file I/O) on every sequential generation.
    cached_transformer_path: Option<PathBuf>,
    /// Force block-level offloading (--offload / MOLD_OFFLOAD=1).
    offload: bool,
    /// Fingerprints of the LoRA stack currently baked into the transformer.
    /// Empty when no LoRAs are active. Order-sensitive: changing the
    /// stack triggers a transformer rebuild on the next generate.
    active_lora: Vec<LoraFingerprint>,
    /// CPU-resident cache of pre-computed LoRA deltas, shared across transformer rebuilds.
    lora_delta_cache: Arc<Mutex<super::lora::LoraDeltaCache>>,
    /// Optional shared tokenizer pool for cross-engine caching.
    shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
    /// Per-request placement override. Set at the start of `generate()`,
    /// cleared on exit. `None` preserves the existing VRAM-aware auto logic.
    pending_placement: Option<mold_core::types::DevicePlacement>,
}
1030
1031impl FluxEngine {
1032    /// Create a new FluxEngine. Does not load models until `load()` is called.
1033    /// `is_schnell_override` lets callers explicitly set the scheduler family.
1034    /// `t5_variant` controls T5 encoder selection: None/"auto" = VRAM-based auto-select,
1035    /// "fp16" = force FP16, "q8"/"q5"/etc = specific quantized variant.
1036    #[allow(clippy::too_many_arguments)]
1037    pub fn new(
1038        model_name: String,
1039        paths: ModelPaths,
1040        is_schnell_override: Option<bool>,
1041        t5_variant: Option<String>,
1042        load_strategy: LoadStrategy,
1043        gpu_ordinal: usize,
1044        offload: bool,
1045        shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
1046    ) -> Self {
1047        Self {
1048            base: EngineBase::new(model_name, paths, load_strategy, gpu_ordinal),
1049            is_schnell_override,
1050            t5_variant,
1051            prompt_cache: Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY)),
1052            transformer_is_fp8: None,
1053            cached_transformer_path: None,
1054            offload,
1055            active_lora: Vec::new(),
1056            lora_delta_cache: Arc::new(Mutex::new(super::lora::LoraDeltaCache::new())),
1057            shared_pool,
1058            pending_placement: None,
1059        }
1060    }
1061
1062    /// Return the LoRA delta cache handle, or `None` when disabled via
1063    /// `MOLD_FLUX_DELTA_CACHE=0`. The cache stores CPU-resident F32 delta
1064    /// tensors for every LoRA-touched layer; on a typical FLUX LoRA that's
1065    /// ~25 GB of standing memory which dominates host RAM use during Q8+LoRA
1066    /// rebuilds. Disabling forces a sub-second `B@A·scale` recompute on the
1067    /// next rebuild, which is cheap on GPU.
1068    fn lora_delta_cache_handle(&self) -> Option<Arc<Mutex<super::lora::LoraDeltaCache>>> {
1069        if std::env::var("MOLD_FLUX_DELTA_CACHE")
1070            .map(|v| v == "0")
1071            .unwrap_or(false)
1072        {
1073            None
1074        } else {
1075            Some(self.lora_delta_cache.clone())
1076        }
1077    }
1078
1079    /// Try to get a cached tokenizer from the shared pool.
1080    fn get_cached_tokenizer(&self, path: &std::path::Path) -> Option<Arc<tokenizers::Tokenizer>> {
1081        let pool = self.shared_pool.as_ref()?;
1082        let pool = pool.lock().unwrap();
1083        pool.get_tokenizer(&path.to_string_lossy())
1084    }
1085
1086    /// Store a tokenizer in the shared pool.
1087    fn cache_tokenizer(&self, path: &std::path::Path, tokenizer: Arc<tokenizers::Tokenizer>) {
1088        if let Some(ref pool) = self.shared_pool {
1089            let mut pool = pool.lock().unwrap();
1090            pool.insert_tokenizer(path.to_string_lossy().into_owned(), tokenizer);
1091        }
1092    }
1093
1094    /// Load VAE weights through the shared CPU tensor cache when available.
1095    fn load_vae_var_builder<'a>(
1096        &self,
1097        dtype: DType,
1098        device: &Device,
1099        component: &str,
1100    ) -> Result<VarBuilder<'a>> {
1101        if let Some(pool) = &self.shared_pool {
1102            let cached = pool
1103                .lock()
1104                .unwrap()
1105                .load_cpu_tensors(std::slice::from_ref(&self.base.paths.vae))?;
1106            let vb = crate::encoders::park::varbuilder_from_parked(cached.as_ref(), dtype, device);
1107            return Ok(flux_vae_var_builder(vb));
1108        }
1109
1110        let vb = crate::weight_loader::load_safetensors_with_progress(
1111            std::slice::from_ref(&self.base.paths.vae),
1112            dtype,
1113            device,
1114            component,
1115            &self.base.progress,
1116        )?;
1117        Ok(flux_vae_var_builder(vb))
1118    }
1119
1120    fn get_cached_safetensors(&self, path: &Path) -> Result<Option<Arc<HashMap<String, Tensor>>>> {
1121        let Some(pool) = &self.shared_pool else {
1122            return Ok(None);
1123        };
1124        let paths = [path];
1125        pool.lock().unwrap().load_safetensors_cpu_tensors(&paths)
1126    }
1127
1128    fn restore_prompt_cache(
1129        progress: &ProgressReporter,
1130        prompt_cache: &Mutex<LruCache<String, CachedTensorPair>>,
1131        prompt: &str,
1132        device: &Device,
1133        dtype: DType,
1134    ) -> Result<Option<(candle_core::Tensor, candle_core::Tensor)>> {
1135        let restored =
1136            restore_cached_tensor_pair(prompt_cache, &prompt_text_key(prompt), device, dtype)?;
1137        let Some(restored) = restored else {
1138            return Ok(None);
1139        };
1140        progress.cache_hit("prompt conditioning");
1141        Ok(Some(restored))
1142    }
1143
1144    fn store_prompt_cache(
1145        prompt_cache: &Mutex<LruCache<String, CachedTensorPair>>,
1146        prompt: &str,
1147        t5_emb: &candle_core::Tensor,
1148        clip_emb: &candle_core::Tensor,
1149    ) -> Result<()> {
1150        store_cached_tensor_pair(prompt_cache, prompt_text_key(prompt), t5_emb, clip_emb)
1151    }
1152}
1153
1154/// Move a conditioning tensor to host RAM if it currently lives on GPU.
1155///
1156/// ComfyUI keeps text-encoder outputs on CPU between encode and denoise so the
1157/// transformer load and LoRA merge see ~50–200 MB more headroom. mirroring
1158/// that here: after `t5.encode(...)` / `clip.encode(...)` we call this, then
1159/// move the tensor back to GPU only at `State::new` time inside the denoise
1160/// loop. Idempotent — when the encoder already produced a CPU tensor (the
1161/// GGUF / Q8 dequant path) this is a cheap pass-through with no copy.
1162pub(crate) fn park_cond_to_cpu(tensor: &candle_core::Tensor) -> Result<candle_core::Tensor> {
1163    if tensor.device().is_cpu() {
1164        return Ok(tensor.clone());
1165    }
1166    Ok(tensor.to_device(&Device::Cpu)?)
1167}
1168
1169impl FluxEngine {
1170    /// Detect is_schnell from override, model name, or transformer filename.
1171    fn detect_is_schnell(&self) -> bool {
1172        self.is_schnell_override.unwrap_or_else(|| {
1173            self.base.model_name.contains("schnell")
1174                || self
1175                    .base
1176                    .paths
1177                    .transformer
1178                    .file_name()
1179                    .and_then(|n| n.to_str())
1180                    .map(|n| n.to_ascii_lowercase().contains("schnell"))
1181                    .unwrap_or(false)
1182        })
1183    }
1184
1185    /// Detect if the transformer is quantized (GGUF).
1186    /// Check if the transformer is FP8 safetensors, caching the result so the
1187    /// file is only probed once (not on every `generate_sequential` call).
1188    fn check_transformer_is_fp8(&mut self, is_quantized: bool) -> bool {
1189        if let Some(cached) = self.transformer_is_fp8 {
1190            return cached;
1191        }
1192        let result = !is_quantized
1193            && flux_safetensors_transformer_is_fp8(&self.base.paths.transformer).unwrap_or(false);
1194        self.transformer_is_fp8 = Some(result);
1195        result
1196    }
1197
1198    fn detect_is_quantized(&self) -> bool {
1199        self.base
1200            .paths
1201            .transformer
1202            .extension()
1203            .and_then(|e| e.to_str())
1204            .map(|e| e.eq_ignore_ascii_case("gguf"))
1205            .unwrap_or(false)
1206    }
1207
1208    /// Validate that all required paths exist.
1209    fn validate_paths(
1210        &self,
1211    ) -> Result<(
1212        std::path::PathBuf,
1213        std::path::PathBuf,
1214        std::path::PathBuf,
1215        std::path::PathBuf,
1216    )> {
1217        let t5_encoder_path = self
1218            .base
1219            .paths
1220            .t5_encoder
1221            .as_ref()
1222            .ok_or_else(|| anyhow::anyhow!("T5 encoder path required for FLUX models"))?
1223            .clone();
1224        let t5_tokenizer_path = self
1225            .base
1226            .paths
1227            .t5_tokenizer
1228            .as_ref()
1229            .ok_or_else(|| anyhow::anyhow!("T5 tokenizer path required for FLUX models"))?
1230            .clone();
1231        let clip_encoder_path = self
1232            .base
1233            .paths
1234            .clip_encoder
1235            .as_ref()
1236            .ok_or_else(|| anyhow::anyhow!("CLIP encoder path required for FLUX models"))?
1237            .clone();
1238        let clip_tokenizer_path = self
1239            .base
1240            .paths
1241            .clip_tokenizer
1242            .as_ref()
1243            .ok_or_else(|| anyhow::anyhow!("CLIP tokenizer path required for FLUX models"))?
1244            .clone();
1245
1246        for (label, path) in [
1247            ("transformer", &self.base.paths.transformer),
1248            ("vae", &self.base.paths.vae),
1249            ("t5_encoder", &t5_encoder_path),
1250            ("clip_encoder", &clip_encoder_path),
1251            ("t5_tokenizer", &t5_tokenizer_path),
1252            ("clip_tokenizer", &clip_tokenizer_path),
1253        ] {
1254            if !path.exists() {
1255                bail!("{label} file not found: {}", path.display());
1256            }
1257        }
1258
1259        Ok((
1260            t5_encoder_path,
1261            t5_tokenizer_path,
1262            clip_encoder_path,
1263            clip_tokenizer_path,
1264        ))
1265    }
1266
    /// Load all model components into GPU memory (Eager mode).
    ///
    /// Order of operations: resolve the transformer device, convert/patch the
    /// transformer file if needed (FP8 → Q8 GGUF cache, missing GGUF
    /// embeddings), load the transformer, load the VAE, then place and load
    /// T5 and CLIP based on the VRAM that is left (or on explicit per-request
    /// placement overrides).
    ///
    /// Returns immediately when components are already loaded, or when the
    /// engine defers loading to `generate_sequential()`.
    ///
    /// On error, `self.base.loaded` remains `None` — all components are assembled into
    /// local variables and only stored in `self.base.loaded` on success, so partial loads
    /// cannot leave the engine in an inconsistent state.
    ///
    /// # Errors
    /// Fails when a required path is missing, when a full-precision
    /// transformer is larger than the usable free VRAM, or when any component
    /// fails to load.
    pub fn load(&mut self) -> Result<()> {
        // Forget the recorded LoRA stack.
        // NOTE(review): this runs before the "already loaded" early return
        // below, so the fingerprint is cleared even when the existing
        // transformer is kept — confirm that is what callers expect.
        self.active_lora = Vec::new();
        if self.base.loaded.is_some() {
            return Ok(());
        }

        // Sequential/offloaded mode defers loading to generate_sequential().
        // The offloaded BF16 transformer is built per request after prompt
        // encoding; eager preload would put the full transformer on GPU and
        // bypass block streaming.
        if self.defers_eager_load() {
            return Ok(());
        }

        let is_schnell = self.detect_is_schnell();
        tracing::info!(model = %self.base.model_name, "loading FLUX model components...");

        let (t5_encoder_path, t5_tokenizer_path, clip_encoder_path, clip_tokenizer_path) =
            self.validate_paths()?;

        let cpu = Device::Cpu;
        // Transformer device: placement override if present, otherwise the
        // device created for this engine's GPU ordinal.
        let transformer_ref = effective_device_ref(
            self.pending_placement.as_ref(),
            |adv| Some(adv.transformer),
            false,
        );
        let device = crate::device::resolve_device(Some(transformer_ref), || {
            crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)
        })?;
        // Mutable: an FP8 checkpoint is swapped for a quantized cache below.
        let mut is_quantized = self.detect_is_quantized();
        let transformer_is_fp8 = self.check_transformer_is_fp8(is_quantized);

        // FP8 safetensors → Q8 GGUF cache: candle lacks native FP8 compute and
        // expanding to F16 doubles VRAM (OOM on 24 GB). Q8 GGUF keeps the model
        // compact (~12 GB) and uses candle's efficient quantized matmul.
        let transformer_path = if transformer_is_fp8 {
            let p = ensure_fp8_gguf_cache(&self.base.paths.transformer, &self.base.progress)?;
            is_quantized = true;
            p
        } else {
            self.base.paths.transformer.clone()
        };

        // Patch city96-format GGUFs missing embedding layers (img_in, time_in, etc.)
        let transformer_path = if is_quantized {
            ensure_gguf_embeddings(&transformer_path, is_schnell, &self.base.progress, None)?
        } else {
            transformer_path
        };

        let gpu_dtype = flux_runtime_dtype(device.is_cuda(), is_quantized, false);

        tracing::info!("GPU device: {:?}, GPU dtype: {:?}", device, gpu_dtype);

        // --- Load FLUX transformer + VAE on GPU first (variable size) ---
        // This must happen before T5/CLIP so we can measure remaining VRAM.

        // Check if full-precision transformer fits in VRAM before attempting load.
        if !is_quantized {
            // Best effort: an unreadable file size becomes 0 and passes the check.
            let xformer_size = std::fs::metadata(&transformer_path)
                .map(|m| m.len())
                .unwrap_or(0);
            // Budget decision: subtract the OS / cuBLAS reserve so we don't
            // promise space the next allocator call cannot deliver.
            let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
            // A zero reading (no value available) disables the check rather
            // than rejecting every model.
            if free > 0 && xformer_size > free {
                bail!(
                    "transformer ({:.1} GB) exceeds available VRAM ({:.1} GB) — \
                     use a quantized model (q8/q4) instead of full-precision for this GPU",
                    xformer_size as f64 / 1e9,
                    free as f64 / 1e9,
                );
            }
        }

        let flux_cfg = if is_schnell {
            flux::model::Config::schnell()
        } else {
            flux::model::Config::dev()
        };

        let xformer_label = if is_quantized {
            "Loading FLUX transformer (GPU, quantized)"
        } else {
            "Loading FLUX transformer (GPU, BF16)"
        };
        self.base.progress.stage_start(xformer_label);
        let xformer_stage = Instant::now();
        tracing::info!(
            path = %transformer_path.display(),
            quantized = is_quantized,
            "loading FLUX transformer on GPU..."
        );

        // Quantized models load straight from GGUF; full-precision models go
        // through the safetensors builder, re-rooted at the transformer prefix.
        let flux_model = if is_quantized {
            let vb = quantized_var_builder::VarBuilder::from_gguf(&transformer_path, &device)?;
            FluxTransformer::Quantized(flux::quantized_model::Flux::new(&flux_cfg, vb)?)
        } else {
            let flux_vb = flux_transformer_var_builder(flux_safetensors_var_builder(
                &transformer_path,
                gpu_dtype,
                &device,
                "FLUX transformer",
                &self.base.progress,
            )?);
            FluxTransformer::BF16(flux::model::Flux::new(&flux_cfg, flux_vb)?)
        };
        self.base
            .progress
            .stage_done(xformer_label, xformer_stage.elapsed());
        tracing::info!("FLUX transformer loaded on GPU");

        // Load VAE on GPU (small, ~300MB)
        // Tier 2: honor `advanced.vae` override.
        // Without an override the VAE shares the transformer's device.
        // NOTE(review): the "(GPU)" stage label and log lines below are fixed
        // text and do not reflect an override to another device.
        let vae_ref =
            effective_device_ref(self.pending_placement.as_ref(), |adv| Some(adv.vae), false);
        let vae_device = crate::device::resolve_device(Some(vae_ref), || Ok(device.clone()))?;
        self.base.progress.stage_start("Loading VAE (GPU)");
        let vae_stage = Instant::now();
        tracing::info!(path = %self.base.paths.vae.display(), "loading VAE on GPU...");
        // Resolve VAE precision once at load time — see LoadedFlux::vae_dtype.
        let vae_dtype = crate::device::resolve_vae_dtype(gpu_dtype);
        let vae_vb = self.load_vae_var_builder(vae_dtype, &vae_device, "VAE")?;
        let vae_cfg = if is_schnell {
            flux::autoencoder::Config::schnell()
        } else {
            flux::autoencoder::Config::dev()
        };
        let vae = flux::autoencoder::AutoEncoder::new(&vae_cfg, vae_vb)?;
        self.base
            .progress
            .stage_done("Loading VAE (GPU)", vae_stage.elapsed());
        tracing::info!("VAE loaded on GPU");

        // --- Decide where to place T5 and CLIP based on remaining VRAM ---
        // Log the raw driver reading (matches `nvidia-smi`) but pass the
        // reserve-adjusted budget to variant resolution so quantized
        // encoders aren't picked when their footprint would push past the
        // OS / cuBLAS workspace headroom.
        let free_raw = free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
        let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
        if free_raw > 0 {
            self.base.progress.info(&format!(
                "Free VRAM after transformer+VAE: {}",
                fmt_gb(free_raw)
            ));
            tracing::info!(
                free_vram = free_raw,
                free_vram_usable = free,
                "free VRAM after loading transformer + VAE"
            );
        }

        // --- T5 encoder: auto-select variant based on VRAM or explicit preference ---
        self.base.progress.stage_start("Selecting T5 encoder");
        let t5_resolve_start = Instant::now();
        let t5_preference = self.t5_variant.as_deref();
        let (resolved_t5_path, t5_on_gpu, _t5_auto_device_label) =
            crate::encoders::variant_resolution::resolve_t5_variant(
                &self.base.progress,
                t5_preference,
                &device,
                free,
                &t5_encoder_path,
            )?;
        self.base
            .progress
            .stage_done("Selecting T5 encoder", t5_resolve_start.elapsed());
        // Tier 2 (if `advanced.t5` populated) overrides Tier 1 text_encoders group knob.
        let t5_ref = effective_device_ref(self.pending_placement.as_ref(), |adv| adv.t5, true);
        let auto_t5_device = if t5_on_gpu {
            device.clone()
        } else {
            cpu.clone()
        };
        let t5_device_owned =
            crate::device::resolve_device(Some(t5_ref), || Ok(auto_t5_device.clone()))?;
        let t5_device = &t5_device_owned;
        // Re-derive from the resolved device so a placement override is
        // reflected in the label and dtype. CPU encoders run in F32.
        let t5_on_gpu = !t5_device.is_cpu();
        let t5_device_label = if t5_on_gpu { "GPU" } else { "CPU" };
        let t5_dtype = if t5_on_gpu { gpu_dtype } else { DType::F32 };

        // Load T5 encoder
        let t5_stage_label = format!("Loading T5 encoder ({t5_device_label})");
        self.base.progress.stage_start(&t5_stage_label);
        let t5_stage = Instant::now();
        tracing::info!(
            path = %resolved_t5_path.display(),
            device = %t5_device_label,
            "loading T5 encoder..."
        );
        // Reuse the tokenizer and CPU tensors from the shared pool when available.
        let cached_t5_tok = self.get_cached_tokenizer(&t5_tokenizer_path);
        let cached_t5_tensors = self.get_cached_safetensors(&resolved_t5_path)?;
        let t5 = encoders::t5::T5Encoder::load_with_tokenizer_and_tensors(
            &resolved_t5_path,
            &t5_tokenizer_path,
            t5_device,
            t5_dtype,
            &self.base.progress,
            cached_t5_tok,
            cached_t5_tensors,
        )?;
        // Publish the tokenizer to the shared pool (no-op without a pool).
        self.cache_tokenizer(&t5_tokenizer_path, t5.tokenizer_arc());
        self.base
            .progress
            .stage_done(&t5_stage_label, t5_stage.elapsed());
        tracing::info!(device = %t5_device_label, "T5 encoder loaded");

        // Re-check VRAM after T5 (it may have consumed GPU memory). Budget
        // decision → reserve-adjusted reading.
        let free_after_t5 = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
        let clip_on_gpu = should_use_gpu(
            device.is_cuda(),
            device.is_metal(),
            free_after_t5,
            CLIP_VRAM_THRESHOLD,
        );
        let clip_ref =
            effective_device_ref(self.pending_placement.as_ref(), |adv| adv.clip_l, true);
        let auto_clip_device = if clip_on_gpu {
            device.clone()
        } else {
            cpu.clone()
        };
        let clip_device_owned =
            crate::device::resolve_device(Some(clip_ref), || Ok(auto_clip_device.clone()))?;
        let clip_device = &clip_device_owned;
        // Same re-derivation as for T5: the resolved device decides dtype and label.
        let clip_on_gpu = !clip_device.is_cpu();
        let clip_dtype = if clip_on_gpu { gpu_dtype } else { DType::F32 };
        let clip_device_label = if clip_on_gpu { "GPU" } else { "CPU" };

        // Load CLIP encoder
        let clip_stage_label = format!("Loading CLIP encoder ({clip_device_label})");
        self.base.progress.stage_start(&clip_stage_label);
        let clip_stage = Instant::now();
        tracing::info!(
            path = %clip_encoder_path.display(),
            device = clip_device_label,
            "loading CLIP encoder..."
        );
        let cached_clip_tok = self.get_cached_tokenizer(&clip_tokenizer_path);
        let cached_clip_tensors = self.get_cached_safetensors(&clip_encoder_path)?;
        let clip = encoders::clip::ClipEncoder::load_with_tokenizer_and_tensors(
            &clip_encoder_path,
            &clip_tokenizer_path,
            clip_device,
            clip_dtype,
            &self.base.progress,
            cached_clip_tok,
            cached_clip_tensors,
        )?;
        self.cache_tokenizer(&clip_tokenizer_path, clip.tokenizer_arc());
        self.base
            .progress
            .stage_done(&clip_stage_label, clip_stage.elapsed());
        tracing::info!(device = clip_device_label, "CLIP encoder loaded");

        // Everything loaded successfully — publish the assembled components.
        self.base.loaded = Some(LoadedFlux {
            flux_model: Some(flux_model),
            t5,
            clip,
            vae,
            device,
            dtype: gpu_dtype,
            vae_dtype,
            is_schnell,
            is_quantized,
            transformer_path,
            t5_encoder_path: resolved_t5_path,
        });

        tracing::info!(model = %self.base.model_name, "all model components loaded successfully");
        Ok(())
    }
1546
1547    /// Generate an image using sequential loading strategy.
1548    ///
1549    /// Loads components one at a time and drops them when done to minimize peak memory:
1550    /// 1. Load T5 → encode → drop T5
1551    /// 2. Load CLIP → encode → drop CLIP
1552    /// 3. Load transformer + VAE → denoise → drop transformer
1553    /// 4. VAE decode → drop VAE
1554    ///
1555    /// Peak memory: max(T5_size, transformer_size + VAE_size) instead of sum(all).
1556    fn generate_sequential(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1557        let is_schnell = self.detect_is_schnell();
1558        let mut is_quantized = self.detect_is_quantized();
1559
1560        let (t5_encoder_path, t5_tokenizer_path, clip_encoder_path, clip_tokenizer_path) =
1561            self.validate_paths()?;
1562
1563        // Check memory budget
1564        if let Some(warning) = check_memory_budget(&self.base.paths, LoadStrategy::Sequential) {
1565            self.base.progress.info(&warning);
1566        }
1567
1568        let transformer_ref = effective_device_ref(
1569            self.pending_placement.as_ref(),
1570            |adv| Some(adv.transformer),
1571            false,
1572        );
1573        let device = crate::device::resolve_device(Some(transformer_ref), || {
1574            crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)
1575        })?;
1576
1577        // Use cached transformer path to avoid file I/O on every sequential call.
1578        let transformer_path = if let Some(ref cached) = self.cached_transformer_path {
1579            if cached
1580                .extension()
1581                .and_then(|e| e.to_str())
1582                .map(|e| e.eq_ignore_ascii_case("gguf"))
1583                .unwrap_or(false)
1584            {
1585                is_quantized = true;
1586            }
1587            cached.clone()
1588        } else {
1589            let transformer_is_fp8 = self.check_transformer_is_fp8(is_quantized);
1590            let p = if transformer_is_fp8 {
1591                let p = ensure_fp8_gguf_cache(&self.base.paths.transformer, &self.base.progress)?;
1592                is_quantized = true;
1593                p
1594            } else {
1595                self.base.paths.transformer.clone()
1596            };
1597            // Patch city96-format GGUFs missing embedding layers
1598            let p = if is_quantized {
1599                ensure_gguf_embeddings(&p, is_schnell, &self.base.progress, None)?
1600            } else {
1601                p
1602            };
1603            self.cached_transformer_path = Some(p.clone());
1604            p
1605        };
1606
1607        let gpu_dtype = flux_runtime_dtype(device.is_cuda(), is_quantized, false);
1608
1609        let start = Instant::now();
1610        let seed = req.seed.unwrap_or_else(rand_seed);
1611
1612        let width = req.width as usize;
1613        let height = req.height as usize;
1614
1615        tracing::info!(
1616            prompt = %req.prompt,
1617            seed, width, height,
1618            steps = req.steps,
1619            "starting sequential FLUX generation"
1620        );
1621
1622        self.base
1623            .progress
1624            .info("Using sequential loading (load-use-drop) to minimize peak memory");
1625
1626        let (t5_emb, clip_emb) = if let Some((t5_emb, clip_emb)) = Self::restore_prompt_cache(
1627            &self.base.progress,
1628            &self.prompt_cache,
1629            &req.prompt,
1630            &device,
1631            gpu_dtype,
1632        )? {
1633            (t5_emb, clip_emb)
1634        } else {
1635            // --- Phase 1: T5 encoding ---
1636            // Reserve-adjusted reading drives the variant choice.
1637            let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
1638            self.base.progress.stage_start("Selecting T5 encoder");
1639            let t5_resolve_start = Instant::now();
1640            let t5_preference = self.t5_variant.as_deref();
1641            let (resolved_t5_path, t5_on_gpu, _t5_auto_device_label) =
1642                crate::encoders::variant_resolution::resolve_t5_variant(
1643                    &self.base.progress,
1644                    t5_preference,
1645                    &device,
1646                    free,
1647                    &t5_encoder_path,
1648                )?;
1649            self.base
1650                .progress
1651                .stage_done("Selecting T5 encoder", t5_resolve_start.elapsed());
1652
1653            let t5_ref = effective_device_ref(self.pending_placement.as_ref(), |adv| adv.t5, true);
1654            let auto_t5_device = if t5_on_gpu {
1655                device.clone()
1656            } else {
1657                Device::Cpu
1658            };
1659            let t5_device_owned =
1660                crate::device::resolve_device(Some(t5_ref), || Ok(auto_t5_device.clone()))?;
1661            let t5_device = &t5_device_owned;
1662            let t5_on_gpu = !t5_device.is_cpu();
1663            let t5_device_label = if t5_on_gpu { "GPU" } else { "CPU" };
1664            let t5_dtype = if t5_on_gpu { gpu_dtype } else { DType::F32 };
1665
1666            let t5_size = std::fs::metadata(&resolved_t5_path)
1667                .map(|m| m.len())
1668                .unwrap_or(0);
1669            // T5 activations: ~256 MB workspace (floor) — small relative to
1670            // the 9 GB encoder weights and only resident during encoding.
1671            let t5_activation_budget = crate::device::activation_bytes(
1672                req.width,
1673                req.height,
1674                1,
1675                crate::device::dtype_bytes(t5_dtype),
1676                crate::device::ActivationFamily::SmallTransformer,
1677            );
1678            preflight_memory_check("T5 encoder", t5_size, t5_activation_budget)?;
1679            if let Some(status) = memory_status_string() {
1680                self.base.progress.info(&status);
1681            }
1682
1683            let t5_stage_label = format!("Loading T5 encoder ({t5_device_label})");
1684            self.base.progress.stage_start(&t5_stage_label);
1685            let t5_stage = Instant::now();
1686            let cached_t5_tok = self.get_cached_tokenizer(&t5_tokenizer_path);
1687            let cached_t5_tensors = self.get_cached_safetensors(&resolved_t5_path)?;
1688            let mut t5 = encoders::t5::T5Encoder::load_with_tokenizer_and_tensors(
1689                &resolved_t5_path,
1690                &t5_tokenizer_path,
1691                t5_device,
1692                t5_dtype,
1693                &self.base.progress,
1694                cached_t5_tok,
1695                cached_t5_tensors,
1696            )?;
1697            self.cache_tokenizer(&t5_tokenizer_path, t5.tokenizer_arc());
1698            self.base
1699                .progress
1700                .stage_done(&t5_stage_label, t5_stage.elapsed());
1701
1702            self.base.progress.stage_start("Encoding prompt (T5)");
1703            let encode_t5 = Instant::now();
1704            // Park to CPU immediately so the transformer load + LoRA merge
1705            // window (next 200–500 ms) doesn't have to budget for ~12 MB of
1706            // T5 output sitting on GPU. Idempotent on the GGUF path where T5
1707            // already produces CPU tensors.
1708            let t5_emb = park_cond_to_cpu(&t5.encode(&req.prompt, &device, gpu_dtype)?)?;
1709            self.base
1710                .progress
1711                .stage_done("Encoding prompt (T5)", encode_t5.elapsed());
1712
1713            drop(t5);
1714            self.base.progress.info("Freed T5 encoder");
1715            tracing::info!("T5 encoder dropped (sequential mode)");
1716
1717            // --- Phase 2: CLIP encoding ---
1718            // Reserve-adjusted reading — should_use_gpu must respect the
1719            // same OS / cuBLAS workspace headroom as the T5 placement above.
1720            let free_for_clip = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
1721            let clip_on_gpu = should_use_gpu(
1722                device.is_cuda(),
1723                device.is_metal(),
1724                free_for_clip,
1725                CLIP_VRAM_THRESHOLD,
1726            );
1727            let clip_ref =
1728                effective_device_ref(self.pending_placement.as_ref(), |adv| adv.clip_l, true);
1729            let auto_clip_device = if clip_on_gpu {
1730                device.clone()
1731            } else {
1732                Device::Cpu
1733            };
1734            let clip_device_owned =
1735                crate::device::resolve_device(Some(clip_ref), || Ok(auto_clip_device.clone()))?;
1736            let clip_device = &clip_device_owned;
1737            let clip_on_gpu = !clip_device.is_cpu();
1738            let clip_dtype = if clip_on_gpu { gpu_dtype } else { DType::F32 };
1739            let clip_device_label = if clip_on_gpu { "GPU" } else { "CPU" };
1740
1741            let clip_stage_label = format!("Loading CLIP encoder ({clip_device_label})");
1742            self.base.progress.stage_start(&clip_stage_label);
1743            let clip_stage = Instant::now();
1744            let cached_clip_tok = self.get_cached_tokenizer(&clip_tokenizer_path);
1745            let cached_clip_tensors = self.get_cached_safetensors(&clip_encoder_path)?;
1746            let clip = encoders::clip::ClipEncoder::load_with_tokenizer_and_tensors(
1747                &clip_encoder_path,
1748                &clip_tokenizer_path,
1749                clip_device,
1750                clip_dtype,
1751                &self.base.progress,
1752                cached_clip_tok,
1753                cached_clip_tensors,
1754            )?;
1755            self.cache_tokenizer(&clip_tokenizer_path, clip.tokenizer_arc());
1756            self.base
1757                .progress
1758                .stage_done(&clip_stage_label, clip_stage.elapsed());
1759
1760            self.base.progress.stage_start("Encoding prompt (CLIP)");
1761            let encode_clip = Instant::now();
1762            // Park to CPU for the same reason as T5 above — keeps the
1763            // TE→transformer transition window from carrying GPU residency
1764            // we don't need.
1765            let clip_emb = {
1766                let mut clip = clip;
1767                park_cond_to_cpu(&clip.encode(&req.prompt, &device, gpu_dtype)?)?
1768            };
1769            self.base
1770                .progress
1771                .stage_done("Encoding prompt (CLIP)", encode_clip.elapsed());
1772
1773            self.base.progress.info("Freed CLIP encoder");
1774            tracing::info!("CLIP encoder dropped (sequential mode)");
1775
1776            // Cache stores via `CachedTensor::from_tensor`, which itself
1777            // moves to CPU; passing CPU tensors here avoids an unnecessary
1778            // round-trip on the GGUF path.
1779            Self::store_prompt_cache(&self.prompt_cache, &req.prompt, &t5_emb, &clip_emb)?;
1780            (t5_emb, clip_emb)
1781        };
1782
1783        // Synchronize to ensure freed T5/CLIP VRAM is reclaimed before
1784        // loading the transformer (critical for FP8 models that expand to F16).
1785        device.synchronize()?;
1786
1787        // --- Phase 3: Load transformer, denoise ---
1788        let xformer_size = std::fs::metadata(&transformer_path)
1789            .map(|m| m.len())
1790            .unwrap_or(0);
1791        let vae_file_size = std::fs::metadata(&self.base.paths.vae)
1792            .map(|m| m.len())
1793            .unwrap_or(0);
1794
1795        // LoRA + GGUF: supported via selective dequantization.
1796        // LoRA-affected layers are dequantized to F32 on CPU, patched, then
1797        // re-quantized back to the original GGML dtype. Non-LoRA tensors are
1798        // left quantized and untouched.
1799
1800        // Per-request activation budget — replaces the fixed 3 GB
1801        // INFERENCE_HEADROOM. Scales with resolution and dtype, so a 768²
1802        // generation isn't false-offloaded on a 16 GB card while a 2048²
1803        // generation isn't under-budgeted.
1804        let activation_budget = crate::device::activation_bytes(
1805            req.width,
1806            req.height,
1807            1, // FLUX is guidance-distilled — single forward per step.
1808            crate::device::dtype_bytes(gpu_dtype),
1809            crate::device::ActivationFamily::FluxDit,
1810        );
1811
1812        // Determine if block-level offloading should be used.
1813        let use_offload = if !is_quantized {
1814            // Reserve-adjusted reading: subtract the OS reserve before passing
1815            // to `should_offload`, which budgets transformer + activation
1816            // headroom against this number.
1817            let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
1818            if self.offload || should_offload(xformer_size, free, activation_budget) {
1819                if free > 0 && free < MIN_OFFLOAD_VRAM {
1820                    bail!(
1821                        "GPU only has {:.1} GB free — at least {:.1} GB is required \
1822                         for block-level offloading",
1823                        free as f64 / 1e9,
1824                        MIN_OFFLOAD_VRAM as f64 / 1e9,
1825                    );
1826                }
1827                true
1828            } else if free > 0 && xformer_size > free {
1829                bail!(
1830                    "transformer ({:.1} GB) exceeds available VRAM ({:.1} GB) — \
1831                     use a quantized model (q8/q4) or --offload for block-level streaming",
1832                    xformer_size as f64 / 1e9,
1833                    free as f64 / 1e9,
1834                );
1835            } else {
1836                false
1837            }
1838        } else {
1839            if self.offload {
1840                tracing::warn!(
1841                    "block-level offloading is not supported for quantized models; \
1842                     --offload / MOLD_OFFLOAD=1 will be ignored"
1843                );
1844            }
1845            false
1846        };
1847
1848        // Even when offloading, blocks must still fit in system RAM on unified-memory
1849        // (Metal) hosts — preflight catches machines with insufficient total memory.
1850        if !use_offload || device.is_metal() {
1851            preflight_memory_check(
1852                "FLUX transformer + VAE",
1853                xformer_size + vae_file_size,
1854                activation_budget,
1855            )?;
1856        }
1857        if let Some(status) = memory_status_string() {
1858            self.base.progress.info(&status);
1859        }
1860
1861        let flux_cfg = if is_schnell {
1862            flux::model::Config::schnell()
1863        } else {
1864            flux::model::Config::dev()
1865        };
1866
1867        let active_loras = effective_loras(req);
1868        let has_lora = !active_loras.is_empty();
1869        let xformer_label = if has_lora && use_offload {
1870            "Loading FLUX transformer + LoRA (offloaded)"
1871        } else if has_lora && is_quantized {
1872            "Loading FLUX transformer + LoRA (GPU, quantized + selective deq)"
1873        } else if has_lora {
1874            "Loading FLUX transformer + LoRA (GPU, BF16)"
1875        } else if use_offload {
1876            "Loading FLUX transformer (offloaded, blocks on CPU)"
1877        } else if is_quantized {
1878            "Loading FLUX transformer (GPU, quantized)"
1879        } else {
1880            "Loading FLUX transformer (GPU, BF16)"
1881        };
1882        self.base.progress.stage_start(xformer_label);
1883        let xformer_stage = Instant::now();
1884
1885        let bypass_mode = LoraBypassMode::from_env();
1886        // For the offloaded path, bypass is the obvious win whenever
1887        // LoRAs are active: the legacy merge path runs `B@A·scale` on
1888        // every targeted CPU-resident BF16 tensor and rebuilds the
1889        // ~24 GB block buffer on every LoRA swap. Bypass keeps adapters
1890        // GPU-resident, so a swap is just a registry replace.
1891        let use_offload_bypass = use_offload && has_lora && bypass_mode != LoraBypassMode::Off;
1892
1893        let flux_model = if use_offload {
1894            // Load transformer blocks on CPU. With bypass enabled the
1895            // base weights are loaded *unmodified* (LoRA contributions
1896            // are added at forward time); without bypass we fall back
1897            // to the merge-on-load path.
1898            let cpu_vb: VarBuilder = if has_lora && !use_offload_bypass {
1899                // Legacy LoRA backend: loads from mmap to CPU, patches inline
1900                flux_lora_var_builder(
1901                    &transformer_path,
1902                    &active_loras,
1903                    gpu_dtype,
1904                    &Device::Cpu,
1905                    &self.base.progress,
1906                    self.lora_delta_cache_handle(),
1907                )?
1908            } else {
1909                flux_transformer_var_builder(flux_safetensors_var_builder(
1910                    &transformer_path,
1911                    gpu_dtype,
1912                    &Device::Cpu,
1913                    "FLUX transformer",
1914                    &self.base.progress,
1915                )?)
1916            };
1917            let mut offloaded = crate::flux::offload::OffloadedFluxTransformer::load(
1918                cpu_vb,
1919                &flux_cfg,
1920                &device,
1921                &self.base.progress,
1922            )?;
1923            if use_offload_bypass {
1924                let registry = build_lora_registry(
1925                    &active_loras,
1926                    &flux_cfg,
1927                    &device,
1928                    gpu_dtype,
1929                    &self.base.progress,
1930                )?;
1931                offloaded.set_lora_registry(registry);
1932            }
1933            FluxTransformer::Offloaded(offloaded)
1934        } else if is_quantized && has_lora {
1935            // GGUF + LoRA: bypass-mode keeps base weights untouched and
1936            // applies LoRA deltas at forward time. Saves the
1937            // dequant→merge→requant cycle that previously cost minutes
1938            // and ~95 GB CPU peak per LoRA load on Q8.
1939            let bypass_quantized = bypass_mode != LoraBypassMode::Off;
1940            if bypass_quantized {
1941                let registry = build_lora_registry(
1942                    &active_loras,
1943                    &flux_cfg,
1944                    &device,
1945                    gpu_dtype,
1946                    &self.base.progress,
1947                )?;
1948                let vb = quantized_var_builder::VarBuilder::from_gguf(&transformer_path, &device)?;
1949                FluxTransformer::QuantizedBypass(
1950                    crate::flux::quantized_transformer::QuantizedFluxTransformer::load(
1951                        &flux_cfg,
1952                        vb,
1953                        registry.as_ref(),
1954                        &self.base.progress,
1955                    )?,
1956                )
1957            } else {
1958                // Legacy fallback: dequantize LoRA-affected layers, keep rest quantized.
1959                let vb = flux_gguf_lora_var_builder(
1960                    &transformer_path,
1961                    &active_loras,
1962                    &device,
1963                    &self.base.progress,
1964                    self.lora_delta_cache_handle(),
1965                )?;
1966                FluxTransformer::Quantized(flux::quantized_model::Flux::new(&flux_cfg, vb)?)
1967            }
1968        } else if is_quantized {
1969            let vb = quantized_var_builder::VarBuilder::from_gguf(&transformer_path, &device)?;
1970            FluxTransformer::Quantized(flux::quantized_model::Flux::new(&flux_cfg, vb)?)
1971        } else if has_lora {
1972            // LoRA without offload (GPU has enough VRAM for full model)
1973            let flux_vb = flux_lora_var_builder(
1974                &transformer_path,
1975                &active_loras,
1976                gpu_dtype,
1977                &device,
1978                &self.base.progress,
1979                self.lora_delta_cache_handle(),
1980            )?;
1981            FluxTransformer::BF16(flux::model::Flux::new(&flux_cfg, flux_vb)?)
1982        } else {
1983            let flux_vb = flux_transformer_var_builder(flux_safetensors_var_builder(
1984                &transformer_path,
1985                gpu_dtype,
1986                &device,
1987                "FLUX transformer",
1988                &self.base.progress,
1989            )?);
1990            FluxTransformer::BF16(flux::model::Flux::new(&flux_cfg, flux_vb)?)
1991        };
1992        self.base
1993            .progress
1994            .stage_done(xformer_label, xformer_stage.elapsed());
1995        if let Some(status) = memory_status_string() {
1996            self.base.progress.info(&status);
1997        }
1998
1999        // Generate noise and build state
2000        let noise_dtype = if is_quantized { DType::F32 } else { gpu_dtype };
2001        let latent_h = height / 16 * 2;
2002        let latent_w = width / 16 * 2;
2003        // Pre-compute timestep schedule (needed before mixing for img2img).
2004        // For non-schnell models the schedule depends on image_seq_len which
2005        // we can derive from latent dimensions without the actual tensor.
2006        let image_seq_len = (latent_h / 2) * (latent_w / 2);
2007        let mut timesteps = if is_schnell {
2008            flux::sampling::get_schedule(req.steps as usize, None)
2009        } else {
2010            flux::sampling::get_schedule(req.steps as usize, Some((image_seq_len, 0.5, 1.15)))
2011        };
2012
2013        if req.source_image.is_some() {
2014            let start_index = crate::img2img::img2img_start_index(req.steps as usize, req.strength);
2015            timesteps = timesteps[start_index..].to_vec();
2016            tracing::info!(
2017                strength = req.strength,
2018                start_index,
2019                start_timestep = timesteps[0],
2020                schedule = ?timesteps,
2021                remaining_steps = timesteps.len().saturating_sub(1),
2022                "img2img: truncated schedule from strength"
2023            );
2024        }
2025
2026        // For img2img we need the VAE before denoising (to encode the source image).
2027        // For txt2img we defer VAE loading until after denoising to maximize VRAM
2028        // available for the transformer — critical for FP8 models expanded to F16.
2029        let vae_cfg = if is_schnell {
2030            flux::autoencoder::Config::schnell()
2031        } else {
2032            flux::autoencoder::Config::dev()
2033        };
2034        // Resolve once so the early img2img encode and the later decode load
2035        // the VAE at the same precision; fixes shape mismatch when
2036        // `MOLD_VAE_DTYPE=fp32` would otherwise upgrade only the decode path.
2037        let early_vae_dtype = crate::device::resolve_vae_dtype(gpu_dtype);
2038
2039        let (img, inpaint_ctx, early_vae) = if let Some(ref source_bytes) = req.source_image {
2040            let start_t = timesteps[0];
2041
2042            // Load VAE early for source image encoding
2043            self.base.progress.stage_start("Loading VAE (GPU)");
2044            let vae_stage = Instant::now();
2045            let vae_vb = self.load_vae_var_builder(early_vae_dtype, &device, "VAE")?;
2046            let vae = flux::autoencoder::AutoEncoder::new(&vae_cfg, vae_vb)?;
2047            self.base
2048                .progress
2049                .stage_done("Loading VAE (GPU)", vae_stage.elapsed());
2050
2051            self.base
2052                .progress
2053                .stage_start("Encoding source image (VAE)");
2054            let encode_start = Instant::now();
2055            let source_tensor = crate::img_utils::decode_source_image(
2056                source_bytes,
2057                req.width,
2058                req.height,
2059                crate::img_utils::NormalizeRange::MinusOneToOne,
2060                &device,
2061                early_vae_dtype,
2062            )?;
2063            // FLUX VAE expects pixels in [-1, 1]; encode applies shift/scale internally
2064            let encoded = vae.encode(&source_tensor)?;
2065            self.base
2066                .progress
2067                .stage_done("Encoding source image (VAE)", encode_start.elapsed());
2068
2069            // Flow-matching img2img: interpolate between encoded latents and noise
2070            // at the exact noise level matching the first timestep in the schedule
2071            let noise = crate::engine::seeded_randn(
2072                seed,
2073                &[1, 16, latent_h, latent_w],
2074                &device,
2075                noise_dtype,
2076            )?;
2077            let encoded = encoded.to_dtype(noise_dtype)?;
2078
2079            // Build inpaint context if mask provided
2080            let inpaint_ctx = if let Some(ref mask_bytes) = req.mask_image {
2081                let mask = crate::img_utils::decode_mask_image(
2082                    mask_bytes,
2083                    latent_h,
2084                    latent_w,
2085                    &device,
2086                    noise_dtype,
2087                )?;
2088                Some(crate::img_utils::InpaintContext {
2089                    original_latents: encoded.clone(),
2090                    mask,
2091                    noise: noise.clone(),
2092                })
2093            } else {
2094                None
2095            };
2096
2097            // latent = (1 - t) * encoded + t * noise
2098            // t matches the first schedule timestep, so denoising starts at the correct level
2099            let img = ((&encoded * (1.0 - start_t))? + (&noise * start_t)?)?;
2100            (img, inpaint_ctx, Some(vae))
2101        } else {
2102            let img = crate::engine::seeded_randn(
2103                seed,
2104                &[1, 16, latent_h, latent_w],
2105                &device,
2106                noise_dtype,
2107            )?;
2108            (img, None, None)
2109        };
2110
2111        // Migrate the parked conditioning tensors back to GPU now that the
2112        // transformer load + LoRA merge phase is over. `to_device` on a
2113        // tensor already on `device` is a no-op clone, so the cache-restore
2114        // path (which returns GPU tensors) costs nothing here.
2115        let t5_emb = t5_emb.to_device(&device)?;
2116        let clip_emb = clip_emb.to_device(&device)?;
2117        let (t5_emb_state, clip_emb_state, img_state) = if is_quantized {
2118            (
2119                t5_emb.to_dtype(DType::F32)?,
2120                clip_emb.to_dtype(DType::F32)?,
2121                img.to_dtype(DType::F32)?,
2122            )
2123        } else {
2124            (t5_emb, clip_emb, img)
2125        };
2126
2127        let state = flux::sampling::State::new(&t5_emb_state, &clip_emb_state, &img_state)?;
2128        let inpaint_ctx = inpaint_ctx
2129            .as_ref()
2130            .map(crate::img2img::pack_flux_inpaint_context)
2131            .transpose()?;
2132
2133        let denoise_label = format!("Denoising ({} steps)", timesteps.len().saturating_sub(1));
2134        self.base.progress.stage_start(&denoise_label);
2135        let denoise_start = Instant::now();
2136
2137        let img = flux_model.denoise(
2138            &state.img,
2139            &state.img_ids,
2140            &state.txt,
2141            &state.txt_ids,
2142            &state.vec,
2143            &timesteps,
2144            req.guidance,
2145            &self.base.progress,
2146            inpaint_ctx.as_ref(),
2147        )?;
2148
2149        let img = flux::sampling::unpack(&img, height, width)?;
2150        self.base
2151            .progress
2152            .stage_done(&denoise_label, denoise_start.elapsed());
2153
2154        // Drop transformer + state to free memory for VAE decode
2155        drop(inpaint_ctx);
2156        drop(flux_model);
2157        self.base.progress.info("Freed FLUX transformer");
2158        drop(state);
2159        drop(t5_emb_state);
2160        drop(clip_emb_state);
2161        drop(img_state);
2162        // Synchronize to ensure CUDA frees dropped memory before VAE allocates
2163        device.synchronize()?;
2164        tracing::info!("Transformer dropped (sequential mode), decoding VAE...");
2165
2166        // --- Phase 4: VAE decode ---
2167        // Use VAE from img2img path if already loaded, otherwise load now
2168        // (deferred loading saves ~300MB VRAM during denoising for FP8 models).
2169        // Sequential path resolves MOLD_VAE_DTYPE per request — env changes
2170        // take effect on the next generate() without an engine reload.
2171        let vae_dtype = crate::device::resolve_vae_dtype(gpu_dtype);
2172        let vae = if let Some(vae) = early_vae {
2173            vae
2174        } else {
2175            self.base.progress.stage_start("Loading VAE (GPU)");
2176            let vae_stage = Instant::now();
2177            let vae_vb = self.load_vae_var_builder(vae_dtype, &device, "VAE")?;
2178            let vae = flux::autoencoder::AutoEncoder::new(&vae_cfg, vae_vb)?;
2179            self.base
2180                .progress
2181                .stage_done("Loading VAE (GPU)", vae_stage.elapsed());
2182            vae
2183        };
2184        self.base.progress.stage_start("VAE decode");
2185        let vae_decode_start = Instant::now();
2186        let img_for_vae = img.to_dtype(vae_dtype)?;
2187        let device_for_sync = device.clone();
2188        let img = crate::vae_tiling::decode_with_oom_fallback(
2189            &img_for_vae,
2190            |latents| vae.decode(latents).map_err(Into::into),
2191            || {
2192                if let Err(e) = device_for_sync.synchronize() {
2193                    tracing::warn!(
2194                        "FLUX (sequential) device.synchronize() after VAE OOM failed: {e}"
2195                    );
2196                }
2197            },
2198        )?;
2199
2200        let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(DType::U8)?;
2201        let img = img.i(0)?;
2202
2203        self.base
2204            .progress
2205            .stage_done("VAE decode", vae_decode_start.elapsed());
2206        // VAE dropped here
2207
2208        let output_metadata = build_output_metadata(req, seed, None);
2209        let image_bytes = encode_image(
2210            &img,
2211            req.resolved_output_format(),
2212            req.width,
2213            req.height,
2214            output_metadata.as_ref(),
2215        )?;
2216
2217        let generation_time_ms = start.elapsed().as_millis() as u64;
2218        tracing::info!(generation_time_ms, seed, "sequential generation complete");
2219
2220        Ok(GenerateResponse {
2221            images: vec![ImageData {
2222                data: image_bytes,
2223                format: req.resolved_output_format(),
2224                width: req.width,
2225                height: req.height,
2226                index: 0,
2227            }],
2228            generation_time_ms,
2229            model: req.model.clone(),
2230            seed_used: seed,
2231            video: None,
2232            gpu: None,
2233        })
2234    }
2235}
2236
2237impl FluxEngine {
2238    fn defers_eager_load(&mut self) -> bool {
2239        self.base.load_strategy == LoadStrategy::Sequential
2240            || (self.offload && !self.detect_is_quantized())
2241    }
2242
2243    fn uses_sequential_generate_path(&mut self) -> bool {
2244        self.defers_eager_load()
2245    }
2246
2247    fn generate_inner(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
2248        if req.scheduler.is_some() {
2249            tracing::warn!("scheduler selection not supported for FLUX (flow-matching), ignoring");
2250        }
2251
2252        // Sequential mode: load-use-drop each component. Forced FLUX offload
2253        // also routes here because block streaming is chosen after prompt
2254        // encoding; eager preload would put the full BF16 transformer on GPU
2255        // before offload can take effect.
2256        if self.uses_sequential_generate_path() {
2257            return self.generate_sequential(req);
2258        }
2259
2260        // Eager mode: use pre-loaded components
2261        // LoRA is supported — the transformer is rebuilt from disk on each generation
2262        // (dropped for VAE decode), so LoRA is applied during the rebuild via a
2263        // patched VarBuilder. No additional overhead compared to non-LoRA eager mode.
2264        // Borrow progress reporter separately from loaded state.
2265        let progress = &self.base.progress;
2266        let prompt_cache = &self.prompt_cache;
2267
2268        // Grab path references before borrowing loaded mutably
2269        let t5_encoder_path = self
2270            .base
2271            .loaded
2272            .as_ref()
2273            .map(|l| l.t5_encoder_path.clone())
2274            .or_else(|| self.base.paths.t5_encoder.clone())
2275            .ok_or_else(|| anyhow::anyhow!("T5 encoder path required for FLUX models"))?;
2276        let clip_encoder_path = self
2277            .base
2278            .paths
2279            .clip_encoder
2280            .clone()
2281            .ok_or_else(|| anyhow::anyhow!("CLIP encoder path required for FLUX models"))?;
2282        let transformer_path = self
2283            .base
2284            .loaded
2285            .as_ref()
2286            .map(|l| l.transformer_path.clone())
2287            .unwrap_or_else(|| self.base.paths.transformer.clone());
2288
2289        // Captured before we mutably borrow `self.base.loaded` via the
2290        // OptionRestoreGuard below — once that borrow is live, calling
2291        // `self.lora_delta_cache_handle()` would conflict.
2292        let cache_handle = self.lora_delta_cache_handle();
2293
2294        let mut loaded = OptionRestoreGuard::take(&mut self.base.loaded)
2295            .ok_or_else(|| anyhow::anyhow!("model not loaded — call load() first"))?;
2296
2297        let start = Instant::now();
2298        let seed = req.seed.unwrap_or_else(rand_seed);
2299
2300        let width = req.width as usize;
2301        let height = req.height as usize;
2302        let loaded_dtype = loaded.dtype;
2303        let loaded_device = loaded.device.clone();
2304
2305        tracing::info!(
2306            prompt = %req.prompt,
2307            seed,
2308            width,
2309            height,
2310            steps = req.steps,
2311            "starting generation"
2312        );
2313
2314        (|| -> Result<GenerateResponse> {
2315            // Only rebuild the transformer when the LoRA stack changes
2316            // (any adapter swap, scale change, add, remove, or reorder).
2317            let active_loras = effective_loras(req);
2318            let requested_stack = fingerprint_stack(&active_loras);
2319            if requested_stack != self.active_lora {
2320                if loaded.flux_model.is_some() {
2321                    loaded.flux_model = None;
2322                    loaded.device.synchronize()?;
2323                }
2324                self.active_lora = requested_stack;
2325            }
2326
2327            if loaded.flux_model.is_none() {
2328                let has_lora = !active_loras.is_empty();
2329                let xformer_label = match (loaded.is_quantized, has_lora) {
2330                    (true, true) => "Reloading FLUX transformer (GPU, quantized + LoRA)",
2331                    (true, false) => "Reloading FLUX transformer (GPU, quantized)",
2332                    (false, true) if loaded.dtype == DType::F16 => {
2333                        "Reloading FLUX transformer (GPU, FP16 + LoRA)"
2334                    }
2335                    (false, true) => "Reloading FLUX transformer (GPU, BF16 + LoRA)",
2336                    (false, false) if loaded.dtype == DType::F16 => {
2337                        "Reloading FLUX transformer (GPU, FP16)"
2338                    }
2339                    (false, false) => "Reloading FLUX transformer (GPU, BF16)",
2340                };
2341                progress.stage_start(xformer_label);
2342                let reload_start = Instant::now();
2343                let flux_cfg = if loaded.is_schnell {
2344                    flux::model::Config::schnell()
2345                } else {
2346                    flux::model::Config::dev()
2347                };
2348                let bypass_mode = LoraBypassMode::from_env();
2349                loaded.flux_model = Some(if loaded.is_quantized && has_lora {
2350                    // Quantized + LoRA stack. Bypass-mode (default `auto`)
2351                    // installs the LoRA at forward time on top of the
2352                    // quantized base — no dequant→merge→requant. Legacy
2353                    // fallback (MOLD_LORA_BYPASS=off) goes through
2354                    // `gguf_lora_var_builder`.
2355                    let bypass_quantized = bypass_mode != LoraBypassMode::Off;
2356                    if bypass_quantized {
2357                        let registry = build_lora_registry(
2358                            &active_loras,
2359                            &flux_cfg,
2360                            &loaded.device,
2361                            loaded.dtype,
2362                            progress,
2363                        )?;
2364                        let vb = quantized_var_builder::VarBuilder::from_gguf(
2365                            &transformer_path,
2366                            &loaded.device,
2367                        )?;
2368                        FluxTransformer::QuantizedBypass(
2369                            crate::flux::quantized_transformer::QuantizedFluxTransformer::load(
2370                                &flux_cfg,
2371                                vb,
2372                                registry.as_ref(),
2373                                progress,
2374                            )?,
2375                        )
2376                    } else {
2377                        let vb = flux_gguf_lora_var_builder(
2378                            &transformer_path,
2379                            &active_loras,
2380                            &loaded.device,
2381                            progress,
2382                            cache_handle.clone(),
2383                        )?;
2384                        FluxTransformer::Quantized(flux::quantized_model::Flux::new(&flux_cfg, vb)?)
2385                    }
2386                } else if loaded.is_quantized {
2387                    let vb = quantized_var_builder::VarBuilder::from_gguf(
2388                        &transformer_path,
2389                        &loaded.device,
2390                    )?;
2391                    FluxTransformer::Quantized(flux::quantized_model::Flux::new(&flux_cfg, vb)?)
2392                } else if has_lora {
2393                    // BF16 + LoRA stack: merge all deltas during construction
2394                    let flux_vb = flux_lora_var_builder(
2395                        &transformer_path,
2396                        &active_loras,
2397                        loaded.dtype,
2398                        &loaded.device,
2399                        progress,
2400                        cache_handle.clone(),
2401                    )?;
2402                    FluxTransformer::BF16(flux::model::Flux::new(&flux_cfg, flux_vb)?)
2403                } else {
2404                    let flux_vb = flux_transformer_var_builder(flux_safetensors_var_builder(
2405                        &transformer_path,
2406                        loaded.dtype,
2407                        &loaded.device,
2408                        "FLUX transformer",
2409                        progress,
2410                    )?);
2411                    FluxTransformer::BF16(flux::model::Flux::new(&flux_cfg, flux_vb)?)
2412                });
2413                progress.stage_done(xformer_label, reload_start.elapsed());
2414            }
2415
2416            if let Some((t5_emb, clip_emb)) = Self::restore_prompt_cache(
2417                progress,
2418                prompt_cache,
2419                &req.prompt,
2420                &loaded_device,
2421                loaded_dtype,
2422            )? {
2423                return Self::generate_with_embeddings(
2424                    progress,
2425                    req,
2426                    &mut loaded,
2427                    t5_emb,
2428                    clip_emb,
2429                    seed,
2430                    width,
2431                    height,
2432                    start,
2433                    self.base.gpu_ordinal,
2434                );
2435            }
2436
2437            if loaded.t5.model.is_none() {
2438                let label = if loaded.t5.is_parked() {
2439                    "Unparking T5 encoder (CPU→GPU)"
2440                } else {
2441                    "Reloading T5 encoder (GPU)"
2442                };
2443                progress.stage_start(label);
2444                let reload_start = Instant::now();
2445                if loaded.t5.is_parked() {
2446                    loaded.t5.unpark_to_gpu(loaded_dtype, progress)?;
2447                } else {
2448                    loaded.t5.reload(&t5_encoder_path, loaded_dtype, progress)?;
2449                }
2450                progress.stage_done(label, reload_start.elapsed());
2451            }
2452            if loaded.clip.model.is_none() {
2453                let label = if loaded.clip.is_parked() {
2454                    "Unparking CLIP encoder (CPU→GPU)"
2455                } else {
2456                    "Reloading CLIP encoder (GPU)"
2457                };
2458                progress.stage_start(label);
2459                let reload_start = Instant::now();
2460                if loaded.clip.is_parked() {
2461                    loaded.clip.unpark_to_gpu(loaded_dtype, progress)?;
2462                } else {
2463                    loaded
2464                        .clip
2465                        .reload(&clip_encoder_path, loaded_dtype, progress)?;
2466                }
2467                progress.stage_done(label, reload_start.elapsed());
2468            }
2469
2470            progress.stage_start("Encoding prompt (T5)");
2471            let encode_t5 = Instant::now();
2472            // Park to CPU between encode and denoise so the transformer
2473            // load + LoRA merge window (next ~200–500 ms) doesn't have to
2474            // budget for this tensor sitting on GPU. Idempotent on GGUF.
2475            let t5_emb = park_cond_to_cpu(&loaded.t5.encode(
2476                &req.prompt,
2477                &loaded_device,
2478                loaded_dtype,
2479            )?)?;
2480            progress.stage_done("Encoding prompt (T5)", encode_t5.elapsed());
2481            tracing::info!("T5 encoding complete");
2482
2483            progress.stage_start("Encoding prompt (CLIP)");
2484            let encode_clip = Instant::now();
2485            let clip_emb = park_cond_to_cpu(&loaded.clip.encode(
2486                &req.prompt,
2487                &loaded_device,
2488                loaded_dtype,
2489            )?)?;
2490            progress.stage_done("Encoding prompt (CLIP)", encode_clip.elapsed());
2491            tracing::info!("CLIP encoding complete");
2492            // CachedTensor::from_tensor already moves to CPU — passing CPU
2493            // tensors here avoids the round-trip on the GGUF path.
2494            Self::store_prompt_cache(prompt_cache, &req.prompt, &t5_emb, &clip_emb)?;
2495
2496            // Drop or park encoders to free GPU memory for denoising.
2497            //
2498            // Default (`MOLD_KEEP_TE_RAM=0`): drop weights from RAM too. The
2499            // next request re-mmaps from disk (~2-4 s for T5 Q8+LoRAs).
2500            //
2501            // Park mode (`MOLD_KEEP_TE_RAM=1`): move parameters to CPU host
2502            // RAM and drop the GPU copy. Next request only pays a CPU→GPU
2503            // tensor copy (~100-300 ms vs 2-4 s) — mirrors ComfyUI's
2504            // `text_encoder_offload_device()` behavior.
2505            //
2506            // On Metal (unified memory) parking is not a win since CPU and
2507            // GPU share the same physical pool, so we still drop there.
2508            let is_metal = loaded.device.is_metal();
2509            let park_mode = crate::device::keep_te_in_ram() && !is_metal;
2510            let mut dropped_gpu_encoder = false;
2511            if loaded.t5.on_gpu || is_metal {
2512                if loaded.t5.on_gpu {
2513                    dropped_gpu_encoder = true;
2514                }
2515                if park_mode {
2516                    loaded.t5.park_to_cpu()?;
2517                    tracing::info!(
2518                        on_gpu = loaded.t5.on_gpu,
2519                        "T5 encoder parked to CPU host RAM"
2520                    );
2521                } else {
2522                    loaded.t5.drop_weights();
2523                    tracing::info!(
2524                        on_gpu = loaded.t5.on_gpu,
2525                        "T5 encoder dropped to free memory for denoising"
2526                    );
2527                }
2528            }
2529            if loaded.clip.on_gpu || is_metal {
2530                if loaded.clip.on_gpu {
2531                    dropped_gpu_encoder = true;
2532                }
2533                if park_mode {
2534                    loaded.clip.park_to_cpu()?;
2535                    tracing::info!(
2536                        on_gpu = loaded.clip.on_gpu,
2537                        "CLIP encoder parked to CPU host RAM"
2538                    );
2539                } else {
2540                    loaded.clip.drop_weights();
2541                    tracing::info!(
2542                        on_gpu = loaded.clip.on_gpu,
2543                        "CLIP encoder dropped to free memory for denoising"
2544                    );
2545                }
2546            }
2547            // Force CUDA to complete the encoder cuMemFreeAsync before denoising
2548            // begins. Without this, the freed encoder VRAM (~5–6 GB for T5 Q8 +
2549            // CLIP) may not be available when the first denoising step allocates,
2550            // and on a tight 24 GB budget (Q8 transformer kept loaded + LoRAs)
2551            // that pushes VAE decode past the limit later in the pipeline.
2552            if dropped_gpu_encoder {
2553                loaded.device.synchronize()?;
2554            }
2555
2556            Self::generate_with_embeddings(
2557                progress,
2558                req,
2559                &mut loaded,
2560                t5_emb,
2561                clip_emb,
2562                seed,
2563                width,
2564                height,
2565                start,
2566                self.base.gpu_ordinal,
2567            )
2568        })()
2569    }
2570}
2571
2572impl InferenceEngine for FluxEngine {
2573    fn generate(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
2574        self.pending_placement = req.placement.clone();
2575        let result = self.generate_inner(req);
2576        self.pending_placement = None;
2577        result
2578    }
2579
2580    fn model_name(&self) -> &str {
2581        self.base.model_name()
2582    }
2583
2584    fn is_loaded(&self) -> bool {
2585        // Sequential mode is always "ready" — it loads on demand
2586        self.base.is_loaded()
2587    }
2588
2589    fn load(&mut self) -> Result<()> {
2590        FluxEngine::load(self)
2591    }
2592
2593    fn unload(&mut self) {
2594        self.base.unload();
2595        // prompt_cache holds GPU-resident T5/CLIP embedding tensors; clear so
2596        // the unload actually frees VRAM.
2597        clear_cache(&self.prompt_cache);
2598        // active_lora reflects the LoRA currently merged into the loaded
2599        // transformer. After unload there is no transformer, so clear the
2600        // marker — the next reload re-applies whatever is in the request.
2601        self.active_lora = Vec::new();
2602        // lora_delta_cache lives on CPU and survives park so the next reload
2603        // can skip the B @ A · scale recompute. It dies with the engine on Drop.
2604    }
2605
2606    fn set_on_progress(&mut self, callback: ProgressCallback) {
2607        self.base.set_on_progress(callback);
2608    }
2609
2610    fn clear_on_progress(&mut self) {
2611        self.base.clear_on_progress();
2612    }
2613
2614    fn model_paths(&self) -> Option<&mold_core::ModelPaths> {
2615        Some(&self.base.paths)
2616    }
2617}
2618
2619impl FluxEngine {
2620    #[allow(clippy::too_many_arguments)]
2621    fn generate_with_embeddings(
2622        progress: &ProgressReporter,
2623        req: &GenerateRequest,
2624        loaded: &mut LoadedFlux,
2625        t5_emb: candle_core::Tensor,
2626        clip_emb: candle_core::Tensor,
2627        seed: u64,
2628        width: usize,
2629        height: usize,
2630        start: Instant,
2631        gpu_ordinal: usize,
2632    ) -> Result<GenerateResponse> {
2633        // 3. Generate initial noise (F32 for quantized, gpu_dtype for BF16)
2634        let noise_dtype = if loaded.is_quantized {
2635            DType::F32
2636        } else {
2637            loaded.dtype
2638        };
2639        let latent_h = height / 16 * 2;
2640        let latent_w = width / 16 * 2;
2641
2642        // Pre-compute timestep schedule (needed before mixing for img2img).
2643        let image_seq_len = (latent_h / 2) * (latent_w / 2);
2644        let mut timesteps = if loaded.is_schnell {
2645            flux::sampling::get_schedule(req.steps as usize, None)
2646        } else {
2647            flux::sampling::get_schedule(req.steps as usize, Some((image_seq_len, 0.5, 1.15)))
2648        };
2649
2650        if req.source_image.is_some() {
2651            let start_index = crate::img2img::img2img_start_index(req.steps as usize, req.strength);
2652            timesteps = timesteps[start_index..].to_vec();
2653            tracing::info!(
2654                strength = req.strength,
2655                start_index,
2656                start_timestep = timesteps[0],
2657                schedule = ?timesteps,
2658                remaining_steps = timesteps.len().saturating_sub(1),
2659                "img2img: truncated schedule from strength"
2660            );
2661        }
2662
2663        let (img, inpaint_ctx) = if let Some(ref source_bytes) = req.source_image {
2664            let start_t = timesteps[0];
2665
2666            progress.stage_start("Encoding source image (VAE)");
2667            let encode_start = Instant::now();
2668            let source_tensor = crate::img_utils::decode_source_image(
2669                source_bytes,
2670                req.width,
2671                req.height,
2672                crate::img_utils::NormalizeRange::MinusOneToOne,
2673                &loaded.device,
2674                loaded.vae_dtype,
2675            )?;
2676            let encoded = loaded.vae.encode(&source_tensor)?;
2677            progress.stage_done("Encoding source image (VAE)", encode_start.elapsed());
2678
2679            let noise = crate::engine::seeded_randn(
2680                seed,
2681                &[1, 16, latent_h, latent_w],
2682                &loaded.device,
2683                noise_dtype,
2684            )?;
2685            let encoded = encoded.to_dtype(noise_dtype)?;
2686
2687            let inpaint_ctx = if let Some(ref mask_bytes) = req.mask_image {
2688                let mask = crate::img_utils::decode_mask_image(
2689                    mask_bytes,
2690                    latent_h,
2691                    latent_w,
2692                    &loaded.device,
2693                    noise_dtype,
2694                )?;
2695                Some(crate::img_utils::InpaintContext {
2696                    original_latents: encoded.clone(),
2697                    mask,
2698                    noise: noise.clone(),
2699                })
2700            } else {
2701                None
2702            };
2703
2704            // latent = (1 - t) * encoded + t * noise
2705            let img = ((&encoded * (1.0 - start_t))? + (&noise * start_t)?)?;
2706            (img, inpaint_ctx)
2707        } else {
2708            let img = crate::engine::seeded_randn(
2709                seed,
2710                &[1, 16, latent_h, latent_w],
2711                &loaded.device,
2712                noise_dtype,
2713            )?;
2714            (img, None)
2715        };
2716
2717        // Migrate parked conditioning tensors back to GPU now that the
2718        // transformer load + LoRA merge phase is over. `to_device` on a
2719        // tensor already on `loaded.device` is a no-op clone, so the
2720        // cache-restore path costs nothing here.
2721        let t5_emb = t5_emb.to_device(&loaded.device)?;
2722        let clip_emb = clip_emb.to_device(&loaded.device)?;
2723        // For quantized model, state tensors must be F32
2724        let (t5_emb_state, clip_emb_state, img_state) = if loaded.is_quantized {
2725            (
2726                t5_emb.to_dtype(DType::F32)?,
2727                clip_emb.to_dtype(DType::F32)?,
2728                img.to_dtype(DType::F32)?,
2729            )
2730        } else {
2731            (t5_emb, clip_emb, img)
2732        };
2733
2734        // Build sampling state
2735        let state = flux::sampling::State::new(&t5_emb_state, &clip_emb_state, &img_state)?;
2736        let inpaint_ctx = inpaint_ctx
2737            .as_ref()
2738            .map(crate::img2img::pack_flux_inpaint_context)
2739            .transpose()?;
2740
2741        let denoise_label = format!("Denoising ({} steps)", timesteps.len().saturating_sub(1));
2742        progress.stage_start(&denoise_label);
2743        let denoise_start = Instant::now();
2744        tracing::info!(
2745            steps = timesteps.len().saturating_sub(1),
2746            quantized = loaded.is_quantized,
2747            "running denoising loop..."
2748        );
2749
2750        // Denoise — guidance from request (0.0 for schnell, 3.5+ for dev/finetuned)
2751        let img = loaded
2752            .flux_model
2753            .as_ref()
2754            .ok_or_else(|| anyhow::anyhow!("transformer not loaded"))?
2755            .denoise(
2756                &state.img,
2757                &state.img_ids,
2758                &state.txt,
2759                &state.txt_ids,
2760                &state.vec,
2761                &timesteps,
2762                req.guidance,
2763                progress,
2764                inpaint_ctx.as_ref(),
2765            )?;
2766
2767        // 7. Unpack latent to spatial
2768        let img = flux::sampling::unpack(&img, height, width)?;
2769        progress.stage_done(&denoise_label, denoise_start.elapsed());
2770        tracing::info!("denoising complete, decoding VAE...");
2771
2772        // Free denoising intermediates and transformer before VAE decode.
2773        // On discrete GPUs (CUDA), the BF16 transformer alone is ~24GB — VAE
2774        // decode needs that VRAM for conv2d intermediates. For Q8 (~12GB) on a
2775        // 24GB GPU, the transformer can stay resident; dropping forces a full
2776        // `gguf_lora_var_builder` rebuild on the next generation, which peaks
2777        // at ~95GB CPU when LoRAs are applied. `MOLD_FLUX_KEEP_TRANSFORMER=1`
2778        // opts into keeping it loaded across same-LoRA generations.
2779        drop(state);
2780        drop(t5_emb_state);
2781        drop(clip_emb_state);
2782        drop(img_state);
2783        let keep_transformer_env = std::env::var("MOLD_FLUX_KEEP_TRANSFORMER")
2784            .map(|v| v == "1")
2785            .unwrap_or(false);
2786
2787        // Even with KEEP_TRANSFORMER=1 the keep is conditional: VAE decode
2788        // needs a large contiguous conv2d allocation (~2–3 GB peak at 1024²,
2789        // ~10–12 GB at 2048²). When the kept transformer + LoRA-merged
2790        // tensors leave too little headroom (observed at ~3 GB free with a
2791        // 2-LoRA stack on a 24 GB card), the VAE alloc OOMs even though the
2792        // resident transformer size is identical to the no-LoRA case. The
2793        // next request rebuilds — that's the trade-off for not OOMing here.
2794        //
2795        // The headroom budget scales with output resolution via
2796        // [`activation_bytes`] instead of a fixed 5 GB magic — at 1024² the
2797        // budget is the FluxDit floor (~256 MB, the previous 5 GB was wildly
2798        // over-conservative on a busy 24 GB card with KEEP_TRANSFORMER=1)
2799        // while at 2048² it grows past 1 GB, catching what fixed 5 GB only
2800        // approximated.
2801        let vae_headroom_bytes = crate::device::activation_bytes(
2802            req.width,
2803            req.height,
2804            1,
2805            crate::device::dtype_bytes(loaded.dtype),
2806            crate::device::ActivationFamily::FluxDit,
2807        );
2808        let free_before_vae = crate::device::free_vram_bytes(gpu_ordinal).unwrap_or(0);
2809        let force_drop_for_headroom =
2810            keep_transformer_env && free_before_vae > 0 && free_before_vae < vae_headroom_bytes;
2811
2812        if !keep_transformer_env || force_drop_for_headroom {
2813            loaded.flux_model = None;
2814            if force_drop_for_headroom {
2815                tracing::info!(
2816                    free_mb = free_before_vae / 1024 / 1024,
2817                    headroom_mb = vae_headroom_bytes / 1024 / 1024,
2818                    "Transformer force-dropped before VAE decode (free VRAM below \
2819                     resolution-scaled headroom; overrides MOLD_FLUX_KEEP_TRANSFORMER=1 \
2820                     for this request)"
2821                );
2822            } else {
2823                tracing::info!("Transformer dropped to free VRAM for VAE decode");
2824            }
2825        } else {
2826            tracing::info!(
2827                free_mb = free_before_vae / 1024 / 1024,
2828                "Transformer kept loaded (MOLD_FLUX_KEEP_TRANSFORMER=1)"
2829            );
2830        }
2831        // Force CUDA to complete pending operations and release freed memory
2832        // before VAE decode allocates its conv2d intermediates. cuMemFree is
2833        // asynchronous, so the drops above (denoising state + embeddings, plus
2834        // the optional transformer drop) may not have actually returned VRAM
2835        // to the allocator yet. Without this synchronize, VAE decode at 1024²
2836        // OOMs on the first conv allocation — observable on the keep-transformer
2837        // path even on iteration 1 (the drop path used to synchronize here, the
2838        // keep path didn't, which made the bug branch-specific).
2839        loaded.device.synchronize()?;
2840
2841        // 8. Decode with VAE — cast to the VAE's actual loaded dtype (which
2842        // may differ from `loaded.dtype` when MOLD_VAE_DTYPE forces fp32 to
2843        // suppress banding; the quantized-model F32-state case is also
2844        // handled by this cast).
2845        progress.stage_start("VAE decode");
2846        let vae_decode_start = Instant::now();
2847        let img_for_vae = img.to_dtype(loaded.vae_dtype)?;
2848        let vae = &loaded.vae;
2849        let device_for_sync = loaded.device.clone();
2850        let img = crate::vae_tiling::decode_with_oom_fallback(
2851            &img_for_vae,
2852            |latents| vae.decode(latents).map_err(Into::into),
2853            || {
2854                if let Err(e) = device_for_sync.synchronize() {
2855                    tracing::warn!(
2856                        "FLUX (parallel) device.synchronize() after VAE OOM failed: {e}"
2857                    );
2858                }
2859            },
2860        )?;
2861
2862        // 9. Convert to u8 image: clamp to [-1, 1], map to [0, 255]
2863        let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(DType::U8)?;
2864        let img = img.i(0)?; // remove batch dim: [3, H, W]
2865
2866        progress.stage_done("VAE decode", vae_decode_start.elapsed());
2867        tracing::info!("VAE decode complete, encoding output image...");
2868
2869        // 10. Convert candle tensor to image bytes
2870        let output_metadata = build_output_metadata(req, seed, None);
2871        let image_bytes = encode_image(
2872            &img,
2873            req.resolved_output_format(),
2874            req.width,
2875            req.height,
2876            output_metadata.as_ref(),
2877        )?;
2878
2879        let generation_time_ms = start.elapsed().as_millis() as u64;
2880        tracing::info!(generation_time_ms, seed, "generation complete");
2881
2882        Ok(GenerateResponse {
2883            images: vec![ImageData {
2884                data: image_bytes,
2885                format: req.resolved_output_format(),
2886                width: req.width,
2887                height: req.height,
2888                index: 0,
2889            }],
2890            generation_time_ms,
2891            model: req.model.clone(),
2892            seed_used: seed,
2893            video: None,
2894            gpu: None,
2895        })
2896    }
2897}
2898
#[cfg(test)]
mod tests {
    use super::{
        effective_loras, flux_rms_norm_scale_aliases, flux_runtime_dtype,
        flux_transformer_var_builder, park_cond_to_cpu, LoraBypassMode,
    };
    use crate::LoadStrategy;
    use candle_core::{DType, Device, Result, Tensor};
    use candle_nn::VarBuilder;
    use mold_core::{GenerateRequest, LoraWeight, ModelPaths, OutputFormat};
    use std::collections::HashMap;
    use std::path::{Path, PathBuf};
    use std::sync::atomic::{AtomicUsize, Ordering};

    // ── Shared test scaffolding ─────────────────────────────────────────

    /// Monotonic counter mixed into scratch directory names so two
    /// directories created with the same tag in the same nanosecond still
    /// get distinct paths.
    static SCRATCH_DIR_SEQ: AtomicUsize = AtomicUsize::new(0);

    /// Uniquely named directory under the system temp dir that is removed
    /// (best effort) when the guard is dropped.
    ///
    /// Cleaning up in `Drop` rather than with a trailing `remove_dir_all`
    /// means a failing assertion no longer leaks the directory, and every
    /// filesystem test gets the same collision-resistant naming scheme.
    struct ScratchDir {
        root: PathBuf,
    }

    impl ScratchDir {
        /// Create `<tmp>/<tag>-<pid>-<nanos>-<seq>` and return its guard.
        ///
        /// Panics if the directory cannot be created; a test cannot do
        /// anything useful without its scratch space.
        fn new(tag: &str) -> Self {
            let nanos = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map(|elapsed| elapsed.as_nanos())
                .unwrap_or(0);
            let seq = SCRATCH_DIR_SEQ.fetch_add(1, Ordering::Relaxed);
            let root = std::env::temp_dir().join(format!(
                "{tag}-{}-{nanos}-{seq}",
                std::process::id()
            ));
            std::fs::create_dir_all(&root).expect("failed to create scratch directory");
            Self { root }
        }

        /// Path of `name` inside the scratch directory.
        fn join(&self, name: &str) -> PathBuf {
            self.root.join(name)
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            // Best effort: a failed cleanup must not mask the test result.
            let _ = std::fs::remove_dir_all(&self.root);
        }
    }

    /// `MOLD_LORA_BYPASS=on` and `=off` are the two boundaries we
    /// document. Any other value (including unset) must collapse to
    /// `Auto` so we never silently change behaviour because of
    /// stale `MOLD_*` env vars in a developer's shell.
    #[test]
    fn lora_bypass_mode_env_parsing() {
        const VAR: &str = "MOLD_LORA_BYPASS";

        // Remember the ambient value so the test leaves the process
        // environment the way it found it.
        let saved = std::env::var_os(VAR);

        let with_env = |val: Option<&str>| -> LoraBypassMode {
            // SAFETY: mutating the environment is only sound while no other
            // thread accesses it outside std's own synchronised accessors.
            // NOTE(review): assumes no test in this binary reads the
            // environment through foreign (C) code concurrently — confirm.
            unsafe {
                match val {
                    Some(v) => std::env::set_var(VAR, v),
                    None => std::env::remove_var(VAR),
                }
            }
            let mode = LoraBypassMode::from_env();
            // SAFETY: same reasoning as the mutation above.
            unsafe {
                std::env::remove_var(VAR);
            }
            mode
        };
        assert_eq!(with_env(Some("on")), LoraBypassMode::On);
        assert_eq!(with_env(Some("ON")), LoraBypassMode::On);
        assert_eq!(with_env(Some("1")), LoraBypassMode::On);
        assert_eq!(with_env(Some("off")), LoraBypassMode::Off);
        assert_eq!(with_env(Some("0")), LoraBypassMode::Off);
        assert_eq!(with_env(Some("auto")), LoraBypassMode::Auto);
        assert_eq!(with_env(Some("garbage")), LoraBypassMode::Auto);
        assert_eq!(with_env(None), LoraBypassMode::Auto);

        if let Some(previous) = saved {
            // SAFETY: same reasoning as the mutations inside `with_env`.
            unsafe {
                std::env::set_var(VAR, previous);
            }
        }
    }

    /// A checkpoint that stores an RMS-norm tensor under a `.weight` suffix
    /// must yield an alias from the corresponding `.scale` name to it.
    #[test]
    fn flux_rms_norm_aliases_detect_weight_suffix_checkpoint() {
        use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};

        let dir = ScratchDir::new("mold-flux-rms-alias");
        let path = dir.join("flux-rms-weight.safetensors");

        // Single-element F32 tensor; only the tensor *name* matters here.
        let data = 1.0f32.to_le_bytes();
        let mut tensors = HashMap::new();
        tensors.insert(
            "model.diffusion_model.double_blocks.0.img_attn.norm.query_norm.weight".to_string(),
            TensorView::new(SafeDtype::F32, vec![1], &data).unwrap(),
        );
        serialize_to_file(&tensors, &None, &path).unwrap();

        let aliases = flux_rms_norm_scale_aliases(&path).unwrap();
        assert_eq!(
            aliases.get("model.diffusion_model.double_blocks.0.img_attn.norm.query_norm.scale"),
            Some(
                &"model.diffusion_model.double_blocks.0.img_attn.norm.query_norm.weight"
                    .to_string()
            )
        );
    }

    /// Build a `ModelPaths` whose transformer points at `transformer` and
    /// whose remaining entries are fixed placeholder file names. None of the
    /// paths need to exist on disk.
    fn dummy_paths(transformer: &str) -> ModelPaths {
        ModelPaths {
            transformer: PathBuf::from(transformer),
            transformer_shards: Vec::new(),
            vae: PathBuf::from("ae.safetensors"),
            spatial_upscaler: None,
            temporal_upscaler: None,
            distilled_lora: None,
            t5_encoder: Some(PathBuf::from("t5.safetensors")),
            clip_encoder: Some(PathBuf::from("clip.safetensors")),
            t5_tokenizer: Some(PathBuf::from("t5-tokenizer.json")),
            clip_tokenizer: Some(PathBuf::from("clip-tokenizer.json")),
            clip_encoder_2: None,
            clip_tokenizer_2: None,
            text_encoder_files: Vec::new(),
            text_tokenizer: None,
            decoder: None,
        }
    }

    /// A BF16 FLUX engine built with `LoadStrategy::Eager` and the seventh
    /// constructor argument set to `true` (presumably the forced-offload
    /// flag — confirm against `FluxEngine::new`) must report the sequential
    /// generation path.
    #[test]
    fn forced_offload_uses_sequential_generation_path_for_bf16_flux() {
        let mut engine = super::FluxEngine::new(
            "flux-dev:bf16".to_string(),
            dummy_paths("flux1-dev.safetensors"),
            Some(false),
            None,
            LoadStrategy::Eager,
            0,
            true,
            None,
        );

        assert!(engine.uses_sequential_generate_path());
    }

    /// Same engine configuration as the sequential-path test above the
    /// constructor call: it must also report that the eager load is deferred.
    #[test]
    fn forced_offload_defers_eager_load_for_bf16_flux() {
        let mut engine = super::FluxEngine::new(
            "flux-dev:bf16".to_string(),
            dummy_paths("flux1-dev.safetensors"),
            Some(false),
            None,
            LoadStrategy::Eager,
            0,
            true,
            None,
        );

        assert!(engine.defers_eager_load());
    }

    /// Minimal `GenerateRequest` carrying only the fields `effective_loras`
    /// touches (`lora`, `loras`). Every other field is set to a benign
    /// default so the tests don't drift when unrelated request shapes
    /// change.
    fn req_with_loras(
        single: Option<LoraWeight>,
        plural: Option<Vec<LoraWeight>>,
    ) -> GenerateRequest {
        GenerateRequest {
            prompt: String::new(),
            negative_prompt: None,
            model: "flux-dev".to_string(),
            width: 1024,
            height: 1024,
            steps: 4,
            guidance: 0.0,
            seed: None,
            batch_size: 1,
            output_format: Some(OutputFormat::Png),
            embed_metadata: None,
            scheduler: None,
            cfg_plus: None,
            source_image: None,
            edit_images: None,
            strength: 0.75,
            mask_image: None,
            control_image: None,
            control_model: None,
            control_scale: 1.0,
            expand: None,
            original_prompt: None,
            lora: single,
            frames: None,
            fps: None,
            upscale_model: None,
            gif_preview: false,
            enable_audio: None,
            audio_file: None,
            audio_file_path: None,
            source_video: None,
            source_video_path: None,
            keyframes: None,
            pipeline: None,
            loras: plural,
            retake_range: None,
            spatial_upscale: None,
            temporal_upscale: None,
            placement: None,
        }
    }

    /// A slider scrubbed to zero on one of three stacked LoRAs must
    /// drop ONLY that entry from the effective stack.
    #[test]
    fn effective_loras_drops_zero_scale() {
        let req = req_with_loras(
            None,
            Some(vec![
                LoraWeight {
                    path: "p1".into(),
                    scale: 0.8,
                },
                LoraWeight {
                    path: "p2".into(),
                    scale: 0.0,
                },
                LoraWeight {
                    path: "p3".into(),
                    scale: 0.5,
                },
            ]),
        );
        let stack = effective_loras(&req);
        let paths: Vec<&str> = stack.iter().map(|w| w.path.as_str()).collect();
        assert_eq!(
            paths,
            vec!["p1", "p3"],
            "p2 (scale=0.0) must be dropped from the effective stack"
        );
        assert!((stack[0].scale - 0.8).abs() < 1e-9);
        assert!((stack[1].scale - 0.5).abs() < 1e-9);
    }

    /// Negative scales are a legitimate "anti-style" use case.
    #[test]
    fn effective_loras_keeps_negative_scales() {
        let req = req_with_loras(
            None,
            Some(vec![LoraWeight {
                path: "p1".into(),
                scale: -0.3,
            }]),
        );
        let stack = effective_loras(&req);
        assert_eq!(stack.len(), 1);
        assert!((stack[0].scale - (-0.3)).abs() < 1e-9);
    }

    /// Single `lora` form: `scale: 0.0` should be dropped too.
    #[test]
    fn effective_loras_drops_zero_scale_on_single_form() {
        let req = req_with_loras(
            Some(LoraWeight {
                path: "p1".into(),
                scale: 0.0,
            }),
            None,
        );
        assert!(effective_loras(&req).is_empty());
    }

    /// Idempotency: a tensor already on CPU comes back on CPU and equals the
    /// input — `park_cond_to_cpu` must not pay for a redundant copy on the
    /// GGUF / Q8 path where T5 already produces CPU tensors.
    #[test]
    fn park_cond_to_cpu_is_idempotent_for_cpu_tensors() {
        let cpu_tensor = Tensor::zeros((2, 4), DType::F32, &Device::Cpu).unwrap();
        let parked = park_cond_to_cpu(&cpu_tensor).unwrap();
        assert!(parked.device().is_cpu(), "CPU input must stay on CPU");
        assert_eq!(parked.shape(), cpu_tensor.shape());
        assert_eq!(parked.dtype(), cpu_tensor.dtype());
        // "Equals the input" covers the values, not just the metadata.
        assert_eq!(
            parked.to_vec2::<f32>().unwrap(),
            cpu_tensor.to_vec2::<f32>().unwrap(),
            "parked tensor must hold the same values as the input"
        );
    }

    /// `park_cond_to_cpu` output is on CPU regardless of input device.
    #[test]
    fn park_cond_to_cpu_returns_cpu_tensor_for_any_input() {
        let input = Tensor::ones((1, 3), DType::F32, &Device::Cpu).unwrap();
        let parked = park_cond_to_cpu(&input).unwrap();
        assert!(parked.device().is_cpu(), "output must be on CPU");
        assert_eq!(parked.shape(), input.shape());
        assert_eq!(parked.dtype(), input.dtype());
    }

    /// When `img_in.weight` lives at the checkpoint root, the resolved
    /// builder must keep an empty prefix.
    #[test]
    fn flux_var_builder_uses_root_tensors_when_present() -> Result<()> {
        let tensors = HashMap::from([(
            "img_in.weight".to_string(),
            Tensor::zeros((1, 1), DType::F32, &Device::Cpu)?,
        )]);
        let vb = VarBuilder::from_tensors(tensors, DType::F32, &Device::Cpu);
        let resolved = flux_transformer_var_builder(vb);

        assert!(resolved.contains_tensor("img_in.weight"));
        assert_eq!(resolved.prefix(), "");
        Ok(())
    }

    /// When tensors are nested under `model.diffusion_model`, the resolved
    /// builder must be scoped to that prefix so root-relative names resolve.
    #[test]
    fn flux_var_builder_uses_model_diffusion_model_prefix_when_present() -> Result<()> {
        let tensors = HashMap::from([(
            "model.diffusion_model.img_in.weight".to_string(),
            Tensor::zeros((1, 1), DType::F32, &Device::Cpu)?,
        )]);
        let vb = VarBuilder::from_tensors(tensors, DType::F32, &Device::Cpu);
        let resolved = flux_transformer_var_builder(vb);

        assert!(resolved.contains_tensor("img_in.weight"));
        assert_eq!(resolved.prefix(), "model.diffusion_model");
        Ok(())
    }

    /// Pins the dtype table for the non-quantized case (second argument
    /// `false`). NOTE(review): argument meaning is inferred from the test
    /// names as (gpu, quantized, fp8) — confirm against `flux_runtime_dtype`.
    #[test]
    fn flux_runtime_dtype_prefers_f16_for_cuda_fp8_safetensors() {
        assert_eq!(flux_runtime_dtype(true, false, true), DType::F16);
        assert_eq!(flux_runtime_dtype(true, false, false), DType::BF16);
        assert_eq!(flux_runtime_dtype(false, false, true), DType::F32);
    }

    /// With the second argument `true`, the result must depend only on the
    /// first argument: BF16 when it is `true`, F32 otherwise.
    #[test]
    fn flux_runtime_dtype_quantized_matches_gpu_policy() {
        assert_eq!(flux_runtime_dtype(true, true, false), DType::BF16);
        assert_eq!(flux_runtime_dtype(false, true, false), DType::F32);
        assert_eq!(flux_runtime_dtype(true, true, true), DType::BF16);
        assert_eq!(flux_runtime_dtype(false, true, true), DType::F32);
    }

    /// The FP8→GGUF cache file name must embed the source stem and byte
    /// size, so a source file of a different size maps to a different path.
    #[test]
    fn fp8_cache_path_includes_file_size() {
        // Create a temp file with known size to test cache path generation
        let dir = ScratchDir::new("mold-cache-test");
        let fp8_file = dir.join("transformer.safetensors");
        std::fs::write(&fp8_file, vec![0u8; 1024]).unwrap();

        let cache_path = super::fp8_gguf_cache_path(&fp8_file);
        let filename = cache_path.file_name().unwrap().to_str().unwrap();

        // Should contain the file stem and the size
        assert!(
            filename.contains("transformer"),
            "should contain stem: {filename}"
        );
        assert!(
            filename.contains("1024"),
            "should contain file size: {filename}"
        );
        assert!(
            filename.ends_with(".q8_0.gguf"),
            "should end with .q8_0.gguf: {filename}"
        );

        // Different size → different cache path
        std::fs::write(&fp8_file, vec![0u8; 2048]).unwrap();
        let cache_path2 = super::fp8_gguf_cache_path(&fp8_file);
        assert_ne!(
            cache_path, cache_path2,
            "different file sizes should produce different cache paths"
        );
    }

    /// Pins which tensor shapes `q8_0_can_quantize_dims` accepts: the two
    /// 2-D shapes below pass, while conv-style 4-D and 1-D shapes do not.
    #[test]
    fn fp8_q8_cache_quantizes_only_block_aligned_last_dim() {
        assert!(super::q8_0_can_quantize_dims(&[3072, 3072]));
        assert!(super::q8_0_can_quantize_dims(&[1, 64]));
        assert!(
            !super::q8_0_can_quantize_dims(&[256, 256, 3, 3]),
            "conv kernels have total elements divisible by 32, but Q8_0 \
             requires the last dimension itself to be block-aligned"
        );
        assert!(!super::q8_0_can_quantize_dims(&[512, 512, 1, 1]));
        assert!(!super::q8_0_can_quantize_dims(&[3072]));
    }

    /// Tensors under `text_encoders.*` and tensors with empty dims must be
    /// skipped by the FP8 cache; a regular transformer weight must not be.
    #[test]
    fn fp8_q8_cache_skips_bundled_text_encoder_and_scalar_tensors() {
        assert!(super::fp8_cache_should_skip_tensor(
            "text_encoders.clip_l.logit_scale",
            &[]
        ));
        assert!(super::fp8_cache_should_skip_tensor(
            "text_encoders.t5xxl.encoder.block.0.layer.0.SelfAttention.q.weight",
            &[4096, 4096]
        ));
        assert!(super::fp8_cache_should_skip_tensor("some.scalar", &[]));
        assert!(!super::fp8_cache_should_skip_tensor(
            "double_blocks.0.img_attn.qkv.weight",
            &[9216, 3072]
        ));
    }

    /// The FP8 cache path must contain `cache/flux-q8`, even when the source
    /// file does not exist.
    #[test]
    fn fp8_cache_path_lives_under_cache_flux_q8() {
        let path = Path::new("/some/model/my-model.safetensors");
        // File doesn't exist so size will be 0
        let cache_path = super::fp8_gguf_cache_path(path);
        let cache_str = cache_path.to_str().unwrap();
        assert!(
            cache_str.contains("cache/flux-q8"),
            "cache should be under cache/flux-q8: {cache_str}"
        );
    }

    /// Two temp paths requested for the same cache file must differ from
    /// each other and from the final cache path.
    #[test]
    fn fp8_cache_temp_paths_are_unique_per_writer() {
        let cache_path = Path::new("/tmp/agfluxSchnell_realistic23-1234-deadbeef.q8_0.gguf");

        let first = super::fp8_gguf_tmp_path(cache_path);
        let second = super::fp8_gguf_tmp_path(cache_path);

        assert_ne!(first, second);
        assert_ne!(first, cache_path);
        assert_ne!(second, cache_path);
    }

    /// A transformer file name containing `Schnell` in mixed case must still
    /// be detected as a schnell model, even though the model id gives no hint.
    #[test]
    fn detects_schnell_from_uppercase_filename() {
        let engine = super::FluxEngine::new(
            "cv:1153358".to_string(),
            dummy_paths("agfluxSchnell_realistic23.safetensors"),
            None,
            None,
            LoadStrategy::Sequential,
            0,
            false,
            None,
        );

        assert!(engine.detect_is_schnell());
    }

    /// A VAE checkpoint whose tensors sit under a `vae.` prefix must resolve
    /// so that the root-relative `encoder.conv_in.weight` is visible.
    #[test]
    fn flux_vae_var_builder_accepts_vae_prefix() {
        use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};

        let dir = ScratchDir::new("mold-flux-vae-prefix");
        let path = dir.join("vae-prefix.safetensors");

        // Zero-filled F32 buffer sized for a [128, 3, 3, 3] tensor.
        let data = vec![0u8; 128 * 3 * 3 * 3 * std::mem::size_of::<f32>()];
        let shape = vec![128, 3, 3, 3];
        let view = TensorView::new(SafeDtype::F32, shape, &data).unwrap();
        let mut tensors = HashMap::new();
        tensors.insert("vae.encoder.conv_in.weight".to_string(), view);
        serialize_to_file(&tensors, &None, &path).unwrap();

        let vb = crate::weight_loader::load_safetensors_with_progress(
            std::slice::from_ref(&path),
            DType::F32,
            &Device::Cpu,
            "test VAE",
            &crate::progress::ProgressReporter::default(),
        )
        .unwrap();
        let vb = super::flux_vae_var_builder(vb);

        assert!(vb.contains_tensor("encoder.conv_in.weight"));
    }

    // ── Embedding patching tests ────────────────────────────────────────

    /// Helper: write a minimal GGUF file containing the given tensor names.
    /// Each tensor is a tiny 1-element F32 QTensor.
    fn write_test_gguf(path: &Path, tensor_names: &[&str]) {
        let device = Device::Cpu;
        let qtensors: Vec<(String, candle_core::quantized::QTensor)> = tensor_names
            .iter()
            .map(|name| {
                let t = Tensor::zeros(1, DType::F32, &device).unwrap();
                let qt = candle_core::quantized::QTensor::quantize(
                    &t,
                    candle_core::quantized::GgmlDType::F32,
                )
                .unwrap();
                (name.to_string(), qt)
            })
            .collect();
        // The GGUF writer takes borrowed (name, tensor) pairs.
        let refs: Vec<(&str, &candle_core::quantized::QTensor)> =
            qtensors.iter().map(|(n, q)| (n.as_str(), q)).collect();
        let file = std::fs::File::create(path).unwrap();
        let mut writer = std::io::BufWriter::new(file);
        candle_core::quantized::gguf_file::write(&mut writer, &[], &refs).unwrap();
    }

    /// A GGUF that includes `img_in.*` tensors is reported as having
    /// embeddings.
    #[test]
    fn gguf_has_embeddings_true_for_complete() {
        let dir = ScratchDir::new("mold-emb-test-complete");
        let path = dir.join("complete.gguf");
        write_test_gguf(
            &path,
            &[
                "img_in.weight",
                "img_in.bias",
                "double_blocks.0.img_mod.lin.weight",
            ],
        );
        assert!(super::gguf_has_embeddings(&path).unwrap());
    }

    /// A GGUF with only block tensors and `txt_in.weight` is reported as
    /// lacking embeddings.
    #[test]
    fn gguf_has_embeddings_false_for_incomplete() {
        let dir = ScratchDir::new("mold-emb-test-incomplete");
        let path = dir.join("incomplete.gguf");
        write_test_gguf(
            &path,
            &[
                "double_blocks.0.img_mod.lin.weight",
                "single_blocks.0.linear1.weight",
                "txt_in.weight",
            ],
        );
        assert!(!super::gguf_has_embeddings(&path).unwrap());
    }

    /// The patched-embedding cache path must contain `cache/flux-embeddings`,
    /// embed the source stem and byte size, and end with `.patched.gguf`.
    #[test]
    fn embedding_patched_cache_path_format() {
        let dir = ScratchDir::new("mold-emb-cache-fmt");
        let gguf_file = dir.join("ultrareal.gguf");
        std::fs::write(&gguf_file, vec![0u8; 512]).unwrap();

        let cache_path = super::embedding_patched_cache_path(&gguf_file);
        let cache_str = cache_path.to_str().unwrap();
        assert!(
            cache_str.contains("cache/flux-embeddings"),
            "should be under cache/flux-embeddings: {cache_str}"
        );
        let filename = cache_path.file_name().unwrap().to_str().unwrap();
        assert!(
            filename.contains("ultrareal"),
            "should contain stem: {filename}"
        );
        assert!(
            filename.contains("512"),
            "should contain file size: {filename}"
        );
        assert!(
            filename.ends_with(".patched.gguf"),
            "should end with .patched.gguf: {filename}"
        );
    }

    /// A GGUF that already contains `img_in.weight` must be returned
    /// unchanged by `ensure_gguf_embeddings`.
    #[test]
    fn ensure_gguf_embeddings_noop_for_complete() {
        let dir = ScratchDir::new("mold-emb-noop");
        let path = dir.join("complete.gguf");

        // Write a GGUF with img_in.weight present
        write_test_gguf(
            &path,
            &["img_in.weight", "double_blocks.0.img_mod.lin.weight"],
        );

        let progress = crate::progress::ProgressReporter::default();
        let result = super::ensure_gguf_embeddings(&path, false, &progress, None).unwrap();

        // Should return the original path unchanged
        assert_eq!(result, path);
    }

    /// Full patching flow: an incomplete GGUF plus a synthetic reference
    /// model must produce a cached, patched GGUF that has embeddings.
    #[test]
    fn ensure_gguf_embeddings_patches_incomplete_with_reference() {
        // Test the full patching flow using a synthetic reference GGUF.
        // Uses models_dir_override to avoid mutating process-global env vars.
        let dir = ScratchDir::new("mold-emb-patch");

        // Create an incomplete GGUF (city96 format — only diffusion blocks)
        let incomplete_path = dir.join("ultrareal-test.gguf");
        write_test_gguf(
            &incomplete_path,
            &[
                "double_blocks.0.img_mod.lin.weight",
                "single_blocks.0.linear1.weight",
                "txt_in.weight",
                "txt_in.bias",
                "final_layer.linear.weight",
            ],
        );

        // Create a fake reference model at the expected manifest path.
        // flux-dev:q8 transformer lives at <models_dir>/flux-dev-q8/flux1-dev-Q8_0.gguf
        let models_dir = dir.join("models");
        let ref_model_dir = models_dir.join("flux-dev-q8");
        std::fs::create_dir_all(&ref_model_dir).unwrap();
        let ref_path = ref_model_dir.join("flux1-dev-Q8_0.gguf");

        // The reference GGUF has all embedding tensors
        let mut all_tensors: Vec<&str> = super::FLUX_EMBEDDING_TENSORS.to_vec();
        all_tensors.extend_from_slice(super::FLUX_GUIDANCE_EMBEDDING_TENSORS);
        all_tensors.extend_from_slice(&[
            "double_blocks.0.img_mod.lin.weight",
            "txt_in.weight",
            "txt_in.bias",
        ]);
        write_test_gguf(&ref_path, &all_tensors);

        let progress = crate::progress::ProgressReporter::default();
        let result =
            super::ensure_gguf_embeddings(&incomplete_path, false, &progress, Some(&models_dir));

        let patched_path = result.unwrap();
        assert_ne!(
            patched_path, incomplete_path,
            "should return a different cached path"
        );
        assert!(patched_path.exists(), "patched GGUF should exist on disk");
        assert!(
            patched_path.to_str().unwrap().contains("flux-embeddings"),
            "patched file should be in flux-embeddings cache"
        );

        // Verify the patched file contains the embedding tensors
        assert!(
            super::gguf_has_embeddings(&patched_path).unwrap(),
            "patched GGUF should have embeddings"
        );

        // Clean up the cache artefact; the scratch dir removes itself on drop.
        std::fs::remove_file(&patched_path).ok();
        // Removing the cache parent may fail if other tests still use it.
        let _ = std::fs::remove_dir(patched_path.parent().unwrap());
    }

    /// If a patched cache file already exists for the input, it must be
    /// returned directly.
    #[test]
    fn ensure_gguf_embeddings_cache_is_reused() {
        // If a cache file already exists, it should be returned directly
        let dir = ScratchDir::new("mold-emb-reuse");

        let incomplete_path = dir.join("model-for-cache.gguf");
        write_test_gguf(&incomplete_path, &["double_blocks.0.img_mod.lin.weight"]);

        // Pre-create the cache file
        let cache_path = super::embedding_patched_cache_path(&incomplete_path);
        std::fs::create_dir_all(cache_path.parent().unwrap()).unwrap();
        write_test_gguf(
            &cache_path,
            &["img_in.weight", "double_blocks.0.img_mod.lin.weight"],
        );

        let progress = crate::progress::ProgressReporter::default();
        let result =
            super::ensure_gguf_embeddings(&incomplete_path, true, &progress, None).unwrap();

        assert_eq!(result, cache_path, "should return cached file");

        // Clean up the cache artefact; the scratch dir removes itself on drop.
        std::fs::remove_file(&cache_path).ok();
        // Try to clean up cache parent dir (may fail if other tests use it)
        let _ = std::fs::remove_dir(cache_path.parent().unwrap());
    }

    /// A schnell-only models directory must not satisfy a target that needs
    /// guidance tensors, but must satisfy one that does not.
    #[test]
    fn find_flux_reference_skips_schnell_when_dev_needed() {
        // Regression: if only flux-schnell is downloaded, a dev-family target
        // (e.g. ultrareal-v4:q8) would previously pick schnell as reference and
        // then fail mid-patch because schnell lacks guidance_in.
        let dir = ScratchDir::new("mold-ref-picker");
        let models_dir = dir.join("models");
        let schnell_dir = models_dir.join("flux-schnell-q8");
        std::fs::create_dir_all(&schnell_dir).unwrap();
        let schnell_path = schnell_dir.join("flux1-schnell-Q8_0.gguf");

        // Schnell has img_in but not guidance_in — mirrors the real city96 schnell GGUF
        let mut schnell_tensors: Vec<&str> = super::FLUX_EMBEDDING_TENSORS.to_vec();
        schnell_tensors.push("double_blocks.0.img_mod.lin.weight");
        write_test_gguf(&schnell_path, &schnell_tensors);

        // needs_guidance=true must reject the schnell-only state
        let result = super::find_flux_reference_gguf(true, Some(&models_dir));
        assert!(
            result.is_none(),
            "schnell must not be picked as reference for dev targets: got {result:?}"
        );

        // needs_guidance=false (schnell target) accepts the schnell reference
        let result = super::find_flux_reference_gguf(false, Some(&models_dir));
        assert_eq!(result.as_deref(), Some(schnell_path.as_path()));
    }

    /// A dev reference is accepted for guidance targets only once it
    /// actually contains the guidance tensors.
    #[test]
    fn find_flux_reference_accepts_dev_candidate_with_guidance() {
        // Happy path for the needs_guidance branch: a dev reference that has
        // guidance_in is accepted; a dev reference lacking guidance (truncated
        // or swapped file) is rejected.
        let dir = ScratchDir::new("mold-ref-dev");
        let models_dir = dir.join("models");
        let dev_dir = models_dir.join("flux-dev-q8");
        std::fs::create_dir_all(&dev_dir).unwrap();
        let dev_path = dev_dir.join("flux1-dev-Q8_0.gguf");

        // Reference without guidance — needs_guidance=true must reject it.
        let incomplete: Vec<&str> = super::FLUX_EMBEDDING_TENSORS.to_vec();
        write_test_gguf(&dev_path, &incomplete);
        assert!(
            super::find_flux_reference_gguf(true, Some(&models_dir)).is_none(),
            "dev candidate without guidance_in must be rejected for dev targets"
        );

        // Now add guidance tensors — same path should be accepted.
        let mut complete: Vec<&str> = super::FLUX_EMBEDDING_TENSORS.to_vec();
        complete.extend_from_slice(super::FLUX_GUIDANCE_EMBEDDING_TENSORS);
        write_test_gguf(&dev_path, &complete);
        let picked = super::find_flux_reference_gguf(true, Some(&models_dir))
            .expect("complete dev reference must be accepted");
        assert_eq!(picked, dev_path);

        // Schnell target (needs_guidance=false) also accepts the dev candidate.
        let picked = super::find_flux_reference_gguf(false, Some(&models_dir))
            .expect("dev candidate satisfies schnell targets too");
        assert_eq!(picked, dev_path);
    }

    /// A complete flux-krea GGUF must be accepted as the reference when no
    /// base flux-dev GGUF is present.
    #[test]
    fn find_flux_reference_accepts_krea_when_no_base_dev() {
        // flux-krea is a dev-family fine-tune shipped as complete GGUFs — it
        // should serve as a reference for city96-format fine-tunes (UltraReal,
        // etc.) even when the base flux-dev GGUF isn't downloaded.
        let dir = ScratchDir::new("mold-ref-krea");
        let models_dir = dir.join("models");
        let krea_dir = models_dir.join("flux-krea-q8");
        std::fs::create_dir_all(&krea_dir).unwrap();
        let krea_path = krea_dir.join("flux1-krea-dev-Q8_0.gguf");

        let mut complete: Vec<&str> = super::FLUX_EMBEDDING_TENSORS.to_vec();
        complete.extend_from_slice(super::FLUX_GUIDANCE_EMBEDDING_TENSORS);
        write_test_gguf(&krea_path, &complete);

        let picked = super::find_flux_reference_gguf(true, Some(&models_dir))
            .expect("complete flux-krea reference must be accepted for dev targets");
        assert_eq!(picked, krea_path);
    }

    /// The two embedding-name constants together must list exactly the 14
    /// tensors asserted below.
    #[test]
    fn embedding_tensor_names_are_exhaustive() {
        // Verify the const arrays cover all non-diffusion-block tensors that
        // Flux::new() in quantized_model.rs expects (lines 378-416).
        // The model loads: img_in, txt_in, time_in, vector_in, guidance_in (optional),
        // double_blocks, single_blocks, final_layer, pe_embedder (computed, no tensors).
        // txt_in is present in city96 GGUFs. double/single/final are the diffusion blocks.
        // Only the embedding layers (img_in, time_in, vector_in, guidance_in) are missing.
        let all_embedding_names: Vec<&str> = super::FLUX_EMBEDDING_TENSORS
            .iter()
            .chain(super::FLUX_GUIDANCE_EMBEDDING_TENSORS.iter())
            .copied()
            .collect();

        // img_in: linear (weight + bias)
        assert!(all_embedding_names.contains(&"img_in.weight"));
        assert!(all_embedding_names.contains(&"img_in.bias"));

        // time_in: MlpEmbedder (in_layer + out_layer, each with weight + bias)
        assert!(all_embedding_names.contains(&"time_in.in_layer.weight"));
        assert!(all_embedding_names.contains(&"time_in.in_layer.bias"));
        assert!(all_embedding_names.contains(&"time_in.out_layer.weight"));
        assert!(all_embedding_names.contains(&"time_in.out_layer.bias"));

        // vector_in: MlpEmbedder
        assert!(all_embedding_names.contains(&"vector_in.in_layer.weight"));
        assert!(all_embedding_names.contains(&"vector_in.in_layer.bias"));
        assert!(all_embedding_names.contains(&"vector_in.out_layer.weight"));
        assert!(all_embedding_names.contains(&"vector_in.out_layer.bias"));

        // guidance_in: MlpEmbedder (dev only)
        assert!(all_embedding_names.contains(&"guidance_in.in_layer.weight"));
        assert!(all_embedding_names.contains(&"guidance_in.in_layer.bias"));
        assert!(all_embedding_names.contains(&"guidance_in.out_layer.weight"));
        assert!(all_embedding_names.contains(&"guidance_in.out_layer.bias"));

        // Total: 14 tensors
        assert_eq!(all_embedding_names.len(), 14);
    }
}