1use anyhow::{bail, Result};
2use candle_core::{DType, Device, IndexOp, Tensor};
3use candle_nn::VarBuilder;
4use candle_transformers::models::flux;
5use candle_transformers::quantized_var_builder;
6use mold_core::{GenerateRequest, GenerateResponse, ImageData, ModelPaths};
7use std::collections::{BTreeMap, HashMap};
8use std::path::{Path, PathBuf};
9use std::sync::{Arc, Mutex};
10use std::time::Instant;
11
12use crate::cache::{
13 clear_cache, prompt_text_key, restore_cached_tensor_pair, store_cached_tensor_pair,
14 CachedTensorPair, LruCache, DEFAULT_PROMPT_CACHE_CAPACITY,
15};
16use crate::device::{
17 check_memory_budget, effective_device_ref, fmt_gb, free_vram_bytes, memory_status_string,
18 preflight_memory_check, should_offload, should_use_gpu, usable_free_vram_bytes,
19 CLIP_VRAM_THRESHOLD, MIN_OFFLOAD_VRAM,
20};
21use crate::encoders;
22use crate::engine::{rand_seed, InferenceEngine, LoadStrategy, OptionRestoreGuard};
23use crate::engine_base::EngineBase;
24use crate::image::{build_output_metadata, encode_image};
25use crate::progress::{ProgressCallback, ProgressReporter};
26
27use super::transformer::FluxTransformer;
28
29fn flux_transformer_var_builder<'a>(vb: VarBuilder<'a>) -> VarBuilder<'a> {
32 if vb.contains_tensor("img_in.weight") {
33 vb
34 } else if vb.contains_tensor("model.diffusion_model.img_in.weight") {
35 vb.pp("model.diffusion_model")
36 } else if vb.contains_tensor("diffusion_model.img_in.weight") {
37 vb.pp("diffusion_model")
38 } else {
39 vb
40 }
41}
42
43fn flux_vae_var_builder<'a>(vb: VarBuilder<'a>) -> VarBuilder<'a> {
47 if vb.contains_tensor("encoder.conv_in.weight") {
48 vb
49 } else if vb.contains_tensor("first_stage_model.encoder.conv_in.weight") {
50 vb.pp("first_stage_model")
51 } else if vb.contains_tensor("vae.encoder.conv_in.weight") {
52 vb.pp("vae")
53 } else {
54 vb
55 }
56}
57
58fn flux_safetensors_transformer_is_fp8(path: &std::path::Path) -> Result<bool> {
62 let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::multi(&[path])? };
63 for key in [
64 "img_in.weight",
65 "model.diffusion_model.img_in.weight",
66 "diffusion_model.img_in.weight",
67 ] {
68 if let Ok(tensor) = tensors.load(key, &Device::Cpu) {
69 return Ok(tensor.dtype() == DType::F8E4M3);
70 }
71 }
72 Ok(false)
73}
74
75fn flux_runtime_dtype(is_cuda: bool, is_quantized: bool, transformer_is_fp8: bool) -> DType {
76 if is_quantized {
77 if is_cuda {
78 DType::BF16
79 } else {
80 DType::F32
81 }
82 } else if is_cuda && transformer_is_fp8 {
83 DType::F16
87 } else if is_cuda {
88 DType::BF16
89 } else {
90 DType::F32
91 }
92}
93
94fn fp8_gguf_cache_path(path: &Path) -> PathBuf {
99 use std::io::{Read, Seek, SeekFrom};
100 let stem = path
101 .file_stem()
102 .and_then(|s| s.to_str())
103 .unwrap_or("transformer");
104 let size = std::fs::metadata(path).map(|m| m.len()).unwrap_or(0);
105 let sample_offset = size / 4;
109 let content_hash = std::fs::File::open(path)
110 .and_then(|mut f| {
111 f.seek(SeekFrom::Start(sample_offset))?;
112 let mut buf = vec![0u8; 4096];
113 let n = f.read(&mut buf)?;
114 buf.truncate(n);
115 Ok(buf)
116 })
117 .map(|buf| {
118 let mut h: u64 = 0xcbf2_9ce4_8422_2325; for &b in &buf {
120 h ^= b as u64;
121 h = h.wrapping_mul(0x0100_0000_01b3); }
123 format!("{h:016x}")
124 })
125 .unwrap_or_else(|_| "0".to_string());
126 let cache_root = mold_core::Config::mold_dir()
127 .unwrap_or_else(|| PathBuf::from(".mold"))
128 .join("cache")
129 .join("flux-q8");
130 cache_root.join(format!("{stem}-{size}-{content_hash}.q8_0.gguf"))
131}
132
133fn q8_0_can_quantize_dims(dims: &[usize]) -> bool {
134 if dims.len() < 2 {
135 return false;
136 }
137 let block_size = candle_core::quantized::GgmlDType::Q8_0.block_size();
138 dims.last()
139 .is_some_and(|last_dim| *last_dim >= block_size && *last_dim % block_size == 0)
140}
141
/// Returns `true` for tensors left out of the FP8→Q8 GGUF cache: scalars (no
/// dimensions) and anything under the `text_encoders.` namespace.
fn fp8_cache_should_skip_tensor(name: &str, dims: &[usize]) -> bool {
    let is_scalar = dims.is_empty();
    let is_text_encoder = name.starts_with("text_encoders.");
    is_scalar || is_text_encoder
}
145
/// Builds a temporary sibling path for writing `cache_path` before it is renamed into
/// place.
///
/// The final extension is replaced with `tmp.{pid}.{seq}`. `seq` is a process-wide
/// counter, so concurrent writers in the same or different processes never share a
/// temp file.
fn fp8_gguf_tmp_path(cache_path: &Path) -> PathBuf {
    use std::sync::atomic::{AtomicU64, Ordering};

    static NEXT_TMP: AtomicU64 = AtomicU64::new(0);
    let pid = std::process::id();
    let seq = NEXT_TMP.fetch_add(1, Ordering::Relaxed);
    cache_path.with_extension(format!("tmp.{pid}.{seq}"))
}
151
152fn ensure_fp8_gguf_cache(path: &Path, progress: &ProgressReporter) -> Result<PathBuf> {
158 let cache_path = fp8_gguf_cache_path(path);
159 if cache_path.exists() {
160 progress.info(&format!("Using cached Q8 GGUF: {}", cache_path.display()));
161 return Ok(cache_path);
162 }
163
164 let parent = cache_path
165 .parent()
166 .ok_or_else(|| anyhow::anyhow!("invalid cache path: {}", cache_path.display()))?;
167
168 let stem = path
175 .file_stem()
176 .and_then(|s| s.to_str())
177 .unwrap_or("transformer");
178 std::fs::create_dir_all(parent)?;
179 let old_v1 = parent.join(format!("{stem}.q8_0.gguf"));
180 if old_v1.exists() {
181 tracing::info!(path = %old_v1.display(), "removing v1 orphaned FP8 cache");
182 let _ = std::fs::remove_file(&old_v1);
183 }
184 if let Ok(entries) = std::fs::read_dir(parent) {
186 let v2_prefix = format!("{stem}-");
187 let suffix = ".q8_0.gguf";
188 for entry in entries.flatten() {
189 let name = entry.file_name();
190 let Some(name_str) = name.to_str() else {
191 continue;
192 };
193 if !name_str.starts_with(&v2_prefix) || !name_str.ends_with(suffix) {
194 continue;
195 }
196 let middle = &name_str[v2_prefix.len()..name_str.len() - suffix.len()];
198 if !middle.contains('-') && middle.chars().all(|c| c.is_ascii_digit()) {
201 tracing::info!(path = %entry.path().display(), "removing v2 orphaned FP8 cache");
202 let _ = std::fs::remove_file(entry.path());
203 }
204 }
205 }
206
207 progress.info("Converting FP8 checkpoint to Q8 GGUF cache (one-time, may take a few minutes)");
208 tracing::info!(
209 source = %path.display(),
210 cache = %cache_path.display(),
211 "converting FP8 safetensors to Q8_0 GGUF cache"
212 );
213
214 let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::multi(&[path])? };
215
216 let prefix = if tensors.get("img_in.weight").is_ok() {
218 ""
219 } else if tensors.get("model.diffusion_model.img_in.weight").is_ok() {
220 "model.diffusion_model."
221 } else if tensors.get("diffusion_model.img_in.weight").is_ok() {
222 "diffusion_model."
223 } else {
224 ""
225 };
226
227 let all_names: Vec<String> = tensors
229 .tensors()
230 .into_iter()
231 .map(|(name, _)| name)
232 .collect();
233
234 let mut qtensors: Vec<(String, candle_core::quantized::QTensor)> = Vec::new();
235
236 let total = all_names.len();
237 for (i, name) in all_names.iter().enumerate() {
238 if (i + 1) % 50 == 0 || i + 1 == total {
239 progress.info(&format!("Quantizing tensor {}/{total}", i + 1));
240 }
241
242 let tensor = tensors.load(name, &Device::Cpu)?;
243 let out_name = if !prefix.is_empty() && name.starts_with(prefix) {
245 name[prefix.len()..].to_string()
246 } else {
247 name.clone()
248 };
249
250 if fp8_cache_should_skip_tensor(&out_name, tensor.dims()) {
251 continue;
252 }
253
254 let can_quantize = q8_0_can_quantize_dims(tensor.dims());
255
256 let qt = if can_quantize {
257 candle_core::quantized::QTensor::quantize(
258 &tensor,
259 candle_core::quantized::GgmlDType::Q8_0,
260 )?
261 } else {
262 candle_core::quantized::QTensor::quantize(
264 &tensor,
265 candle_core::quantized::GgmlDType::F32,
266 )?
267 };
268 qtensors.push((out_name, qt));
269 }
270
271 let tmp_path = fp8_gguf_tmp_path(&cache_path);
273 let write_result = (|| -> Result<()> {
274 let file = std::fs::File::create(&tmp_path)?;
275 let mut writer = std::io::BufWriter::new(file);
276 let tensor_refs: Vec<(&str, &candle_core::quantized::QTensor)> =
277 qtensors.iter().map(|(n, q)| (n.as_str(), q)).collect();
278 candle_core::quantized::gguf_file::write(&mut writer, &[], &tensor_refs)?;
279 Ok(())
280 })();
281 if let Err(e) = write_result {
282 let _ = std::fs::remove_file(&tmp_path);
283 return Err(e);
284 }
285 if cache_path.exists() {
286 let _ = std::fs::remove_file(&tmp_path);
287 progress.info(&format!("Using cached Q8 GGUF: {}", cache_path.display()));
288 return Ok(cache_path);
289 }
290 std::fs::rename(&tmp_path, &cache_path)?;
291
292 progress.info(&format!("Q8 GGUF cache created: {}", cache_path.display()));
293 tracing::info!(cache = %cache_path.display(), "FP8→Q8_0 GGUF cache created");
294 Ok(cache_path)
295}
296
/// Input-embedding tensors (`img_in`, `time_in`, `vector_in`) that
/// `ensure_gguf_embeddings` copies from a reference GGUF when the source GGUF
/// lacks them. The order here is the order in which missing tensors are appended.
const FLUX_EMBEDDING_TENSORS: &[&str] = &[
    "img_in.weight",
    "img_in.bias",
    "time_in.in_layer.weight",
    "time_in.in_layer.bias",
    "time_in.out_layer.weight",
    "time_in.out_layer.bias",
    "vector_in.in_layer.weight",
    "vector_in.in_layer.bias",
    "vector_in.out_layer.weight",
    "vector_in.out_layer.bias",
];
312
/// Guidance-embedding tensors (`guidance_in`), patched in addition to
/// `FLUX_EMBEDDING_TENSORS` only for non-schnell targets (see `ensure_gguf_embeddings`).
const FLUX_GUIDANCE_EMBEDDING_TENSORS: &[&str] = &[
    "guidance_in.in_layer.weight",
    "guidance_in.in_layer.bias",
    "guidance_in.out_layer.weight",
    "guidance_in.out_layer.bias",
];
320
321fn gguf_has_embeddings(path: &Path) -> Result<bool> {
328 let mut file = std::fs::File::open(path)?;
329 let content = candle_core::quantized::gguf_file::Content::read(&mut file)?;
330 Ok(content.tensor_infos.contains_key("img_in.weight"))
331}
332
333fn gguf_has_guidance(path: &Path) -> Result<bool> {
336 let mut file = std::fs::File::open(path)?;
337 let content = candle_core::quantized::gguf_file::Content::read(&mut file)?;
338 Ok(content
339 .tensor_infos
340 .contains_key("guidance_in.in_layer.weight"))
341}
342
343fn find_flux_reference_gguf(
353 needs_guidance: bool,
354 models_dir_override: Option<&Path>,
355) -> Option<PathBuf> {
356 let config = mold_core::Config::load_or_default();
357 let models_dir = models_dir_override
358 .map(PathBuf::from)
359 .unwrap_or_else(|| config.resolved_models_dir());
360
361 let mut candidates: Vec<&str> = vec![
367 "flux-dev:q8",
368 "flux-dev:q6",
369 "flux-dev:q4",
370 "flux-krea:q8",
371 "flux-krea:q6",
372 "flux-krea:q4",
373 ];
374 if !needs_guidance {
375 candidates.extend(["flux-schnell:q8", "flux-schnell:q4"]);
376 }
377
378 for name in candidates {
379 let Some(manifest) = mold_core::manifest::find_manifest(name) else {
380 continue;
381 };
382 let Some(xformer_file) = manifest
384 .files
385 .iter()
386 .find(|f| f.component == mold_core::manifest::ModelComponent::Transformer)
387 else {
388 continue;
389 };
390 let xformer_path =
391 models_dir.join(mold_core::manifest::storage_path(manifest, xformer_file));
392 if !xformer_path.exists() {
393 continue;
394 }
395 match gguf_has_embeddings(&xformer_path) {
397 Ok(true) => {
398 if needs_guidance {
399 match gguf_has_guidance(&xformer_path) {
400 Ok(true) => {}
401 Ok(false) => {
402 tracing::debug!(
403 model = name,
404 "reference candidate lacks guidance_in, skipping for dev target"
405 );
406 continue;
407 }
408 Err(e) => {
409 tracing::debug!(
410 model = name,
411 err = %e,
412 "failed to probe guidance tensors"
413 );
414 continue;
415 }
416 }
417 }
418 tracing::info!(
419 reference = %xformer_path.display(),
420 model = name,
421 needs_guidance,
422 "found reference FLUX GGUF with embeddings"
423 );
424 return Some(xformer_path);
425 }
426 Ok(false) => {
427 tracing::debug!(
428 model = name,
429 "reference candidate also missing embeddings, skipping"
430 );
431 }
432 Err(e) => {
433 tracing::debug!(model = name, err = %e, "failed to probe reference candidate");
434 }
435 }
436 }
437 None
438}
439
440fn embedding_patched_cache_path(path: &Path) -> PathBuf {
443 use std::io::{Read, Seek, SeekFrom};
444 let stem = path
445 .file_stem()
446 .and_then(|s| s.to_str())
447 .unwrap_or("transformer");
448 let size = std::fs::metadata(path).map(|m| m.len()).unwrap_or(0);
449 let sample_offset = size / 4;
450 let content_hash = std::fs::File::open(path)
451 .and_then(|mut f| {
452 f.seek(SeekFrom::Start(sample_offset))?;
453 let mut buf = vec![0u8; 4096];
454 let n = f.read(&mut buf)?;
455 buf.truncate(n);
456 Ok(buf)
457 })
458 .map(|buf| {
459 let mut h: u64 = 0xcbf2_9ce4_8422_2325;
460 for &b in &buf {
461 h ^= b as u64;
462 h = h.wrapping_mul(0x0100_0000_01b3);
463 }
464 format!("{h:016x}")
465 })
466 .unwrap_or_else(|_| "0".to_string());
467 let cache_root = mold_core::Config::mold_dir()
468 .unwrap_or_else(|| PathBuf::from(".mold"))
469 .join("cache")
470 .join("flux-embeddings");
471 cache_root.join(format!("{stem}-{size}-{content_hash}.patched.gguf"))
472}
473
474fn ensure_gguf_embeddings(
488 path: &Path,
489 is_schnell: bool,
490 progress: &ProgressReporter,
491 models_dir_override: Option<&Path>,
492) -> Result<PathBuf> {
493 let cache_path = embedding_patched_cache_path(path);
494 if cache_path.exists() {
495 progress.info(&format!(
496 "Using cached embedding-patched GGUF: {}",
497 cache_path.display()
498 ));
499 return Ok(cache_path);
500 }
501
502 if gguf_has_embeddings(path)? {
504 return Ok(path.to_path_buf());
505 }
506
507 progress.info(
508 "GGUF is missing FLUX embedding layers (city96 format) — patching from reference model",
509 );
510 tracing::info!(
511 path = %path.display(),
512 is_schnell,
513 "GGUF missing embedding layers, searching for reference model"
514 );
515
516 let source_name = path
517 .file_name()
518 .and_then(|n| n.to_str())
519 .unwrap_or("<unknown>");
520 let needs_guidance = !is_schnell;
521 let reference_path =
522 find_flux_reference_gguf(needs_guidance, models_dir_override).ok_or_else(|| {
523 let family = if needs_guidance { "dev" } else { "schnell" };
524 anyhow::anyhow!(
525 "{source_name} is a city96-format GGUF that ships only the diffusion \
526 blocks — its FLUX input embedding layers (img_in, time_in, vector_in{guidance}) \
527 must be sourced from a complete flux-{family} GGUF, but none is downloaded.\n\n\
528 To fix this:\n\n mold pull flux-dev:q8\n\n\
529 Then retry — mold will patch the incomplete GGUF from the reference.",
530 guidance = if needs_guidance { ", guidance_in" } else { "" },
531 )
532 })?;
533
534 let mut needed: Vec<&str> = FLUX_EMBEDDING_TENSORS.to_vec();
536 if !is_schnell {
537 needed.extend_from_slice(FLUX_GUIDANCE_EMBEDDING_TENSORS);
538 }
539
540 progress.info("Reading source GGUF tensors...");
542 let mut src_file = std::fs::File::open(path)?;
543 let src_content = candle_core::quantized::gguf_file::Content::read(&mut src_file)?;
544
545 progress.info(&format!(
547 "Extracting {} embedding tensors from reference: {}",
548 needed.len(),
549 reference_path
550 .file_name()
551 .and_then(|n| n.to_str())
552 .unwrap_or("?")
553 ));
554 let mut ref_file = std::fs::File::open(&reference_path)?;
555 let ref_content = candle_core::quantized::gguf_file::Content::read(&mut ref_file)?;
556
557 let cpu = Device::Cpu;
558
559 let mut qtensors: Vec<(String, candle_core::quantized::QTensor)> = Vec::new();
561 let total = src_content.tensor_infos.len();
562 for (i, name) in src_content.tensor_infos.keys().enumerate() {
563 if (i + 1) % 100 == 0 || i + 1 == total {
564 progress.info(&format!("Loading source tensor {}/{total}", i + 1));
565 }
566 let tensor = src_content.tensor(&mut src_file, name, &cpu)?;
567 qtensors.push((name.clone(), tensor));
568 }
569
570 let mut patched_count = 0usize;
572 for name in &needed {
573 if src_content.tensor_infos.contains_key(*name) {
574 continue; }
576 if !ref_content.tensor_infos.contains_key(*name) {
577 bail!(
578 "while patching {source_name}: the only downloaded reference ({}) \
579 is also missing '{name}'. This model needs a complete flux-dev GGUF \
580 — run 'mold pull flux-dev:q8' and retry.",
581 reference_path
582 .file_name()
583 .and_then(|n| n.to_str())
584 .unwrap_or("<unknown>"),
585 );
586 }
587 let tensor = ref_content.tensor(&mut ref_file, name, &cpu)?;
588 tracing::debug!(tensor = name, "patching embedding tensor from reference");
589 qtensors.push((name.to_string(), tensor));
590 patched_count += 1;
591 }
592
593 progress.info(&format!(
594 "Patched {patched_count} embedding tensors from reference"
595 ));
596
597 let parent = cache_path
599 .parent()
600 .ok_or_else(|| anyhow::anyhow!("invalid cache path: {}", cache_path.display()))?;
601 std::fs::create_dir_all(parent)?;
602 let tmp_path = cache_path.with_extension(format!("tmp.{}", std::process::id()));
603 let write_result = (|| -> Result<()> {
604 let file = std::fs::File::create(&tmp_path)?;
605 let mut writer = std::io::BufWriter::new(file);
606 let tensor_refs: Vec<(&str, &candle_core::quantized::QTensor)> =
607 qtensors.iter().map(|(n, q)| (n.as_str(), q)).collect();
608 candle_core::quantized::gguf_file::write(&mut writer, &[], &tensor_refs)?;
609 Ok(())
610 })();
611 if let Err(e) = write_result {
612 let _ = std::fs::remove_file(&tmp_path);
613 return Err(e);
614 }
615 std::fs::rename(&tmp_path, &cache_path)?;
616
617 progress.info(&format!(
618 "Embedding-patched GGUF cache created: {}",
619 cache_path.display()
620 ));
621 tracing::info!(
622 cache = %cache_path.display(),
623 patched_count,
624 "embedding-patched GGUF cache created"
625 );
626 Ok(cache_path)
627}
628
629fn flux_safetensors_var_builder<'a>(
630 path: &std::path::Path,
631 dtype: DType,
632 device: &Device,
633 component: &str,
634 progress: &ProgressReporter,
635) -> Result<VarBuilder<'a>> {
636 let aliases = flux_rms_norm_scale_aliases(path)?;
637 if aliases.is_empty() {
638 crate::weight_loader::load_safetensors_with_progress(
639 std::slice::from_ref(&path),
640 dtype,
641 device,
642 component,
643 progress,
644 )
645 } else {
646 tracing::info!(
647 alias_count = aliases.len(),
648 path = %path.display(),
649 "FLUX checkpoint uses RMSNorm .weight keys; aliasing .scale lookups"
650 );
651 crate::weight_loader::load_safetensors_with_aliases(
652 std::slice::from_ref(&path),
653 dtype,
654 device,
655 component,
656 progress,
657 aliases,
658 )
659 }
660}
661
662fn flux_rms_norm_scale_aliases(path: &std::path::Path) -> Result<BTreeMap<String, String>> {
663 let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::multi(&[path])? };
664 let mut aliases = BTreeMap::new();
665 for prefix in ["", "model.diffusion_model.", "diffusion_model."] {
666 for i in 0..64 {
667 for stream in ["img_attn", "txt_attn"] {
668 for norm in ["query_norm", "key_norm"] {
669 let target = format!("{prefix}double_blocks.{i}.{stream}.norm.{norm}.scale");
670 let source = format!("{prefix}double_blocks.{i}.{stream}.norm.{norm}.weight");
671 if tensors.get(&target).is_err() && tensors.get(&source).is_ok() {
672 aliases.insert(target, source);
673 }
674 }
675 }
676 }
677 for i in 0..128 {
678 for norm in ["query_norm", "key_norm"] {
679 let target = format!("{prefix}single_blocks.{i}.norm.{norm}.scale");
680 let source = format!("{prefix}single_blocks.{i}.norm.{norm}.weight");
681 if tensors.get(&target).is_err() && tensors.get(&source).is_ok() {
682 aliases.insert(target, source);
683 }
684 }
685 }
686 }
687 Ok(aliases)
688}
689
690fn flux_lora_var_builder<'a>(
701 transformer_path: &Path,
702 loras: &[mold_core::LoraWeight],
703 dtype: DType,
704 device: &Device,
705 progress: &ProgressReporter,
706 delta_cache: Option<std::sync::Arc<std::sync::Mutex<super::lora::LoraDeltaCache>>>,
707) -> Result<VarBuilder<'a>> {
708 use super::lora;
709
710 let adapters: Vec<std::sync::Arc<lora::LoraAdapter>> = loras
711 .iter()
712 .map(|w| {
713 progress.info("Loading LoRA adapter");
714 let adapter = lora::get_or_load_adapter(Path::new(&w.path))?;
715 progress.info(&format!(
716 "LoRA: {} layers, rank {}, scale {:.2}",
717 adapter.layers.len(),
718 adapter.rank,
719 w.scale,
720 ));
721 anyhow::Ok(adapter)
722 })
723 .collect::<Result<_>>()?;
724
725 let specs: Vec<lora::LoraSpec<'_>> = adapters
726 .iter()
727 .zip(loras.iter())
728 .map(|(adapter, w)| lora::LoraSpec {
729 adapter: adapter.as_ref(),
730 scale: w.scale,
731 path_hash: lora_path_hash(&w.path),
732 })
733 .collect();
734
735 lora::lora_var_builder(
736 transformer_path,
737 &specs,
738 dtype,
739 device,
740 progress,
741 delta_cache,
742 )
743}
744
/// Hashes a LoRA path string with `DefaultHasher`.
///
/// This produces the `path_hash` used in `LoraSpec`s and LoRA fingerprints.
/// `DefaultHasher`'s algorithm is not guaranteed stable across Rust releases, so the
/// value should not be persisted.
fn lora_path_hash(path: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut state = DefaultHasher::new();
    Hash::hash(path, &mut state);
    state.finish()
}
754
755fn flux_gguf_lora_var_builder(
757 transformer_path: &Path,
758 loras: &[mold_core::LoraWeight],
759 device: &Device,
760 progress: &ProgressReporter,
761 delta_cache: Option<std::sync::Arc<std::sync::Mutex<super::lora::LoraDeltaCache>>>,
762) -> Result<candle_transformers::quantized_var_builder::VarBuilder> {
763 use super::lora;
764
765 let adapters: Vec<std::sync::Arc<lora::LoraAdapter>> = loras
766 .iter()
767 .map(|w| {
768 progress.info("Loading LoRA adapter");
769 let adapter = lora::get_or_load_adapter(Path::new(&w.path))?;
770 progress.info(&format!(
771 "LoRA: {} layers, rank {}, scale {:.2}",
772 adapter.layers.len(),
773 adapter.rank,
774 w.scale,
775 ));
776 anyhow::Ok(adapter)
777 })
778 .collect::<Result<_>>()?;
779
780 let specs: Vec<lora::LoraSpec<'_>> = adapters
781 .iter()
782 .zip(loras.iter())
783 .map(|(adapter, w)| lora::LoraSpec {
784 adapter: adapter.as_ref(),
785 scale: w.scale,
786 path_hash: lora_path_hash(&w.path),
787 })
788 .collect();
789
790 lora::gguf_lora_var_builder(transformer_path, &specs, device, progress, delta_cache)
791}
792
/// LoRA bypass selection, read from the `MOLD_LORA_BYPASS` environment variable.
///
/// `On`/`Off` force the choice; `Auto` leaves it to the caller's heuristics.
/// "Bypass" presumably refers to the device-resident adapter registry built by
/// `build_lora_registry` rather than merged weights — confirm at the call site.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum LoraBypassMode {
    Auto,
    On,
    Off,
}

impl LoraBypassMode {
    /// Parses `MOLD_LORA_BYPASS` after trimming and ASCII-lowercasing it.
    ///
    /// `on`/`1`/`true` map to `On`, and `off`/`0`/`false` map to `Off`. An unset or
    /// non-Unicode variable, or any other value, maps to `Auto`.
    fn from_env() -> Self {
        let raw = match std::env::var("MOLD_LORA_BYPASS") {
            Ok(value) => value,
            Err(_) => return Self::Auto,
        };
        match raw.trim().to_ascii_lowercase().as_str() {
            "on" | "1" | "true" => Self::On,
            "off" | "0" | "false" => Self::Off,
            _ => Self::Auto,
        }
    }
}
822
823fn build_lora_registry(
834 loras: &[mold_core::LoraWeight],
835 cfg: &flux::model::Config,
836 device: &Device,
837 dtype: DType,
838 progress: &ProgressReporter,
839) -> Result<Option<super::lora_bypass::LoraRegistry>> {
840 use super::lora;
841 use super::lora_bypass;
842
843 if loras.is_empty() {
844 return Ok(None);
845 }
846
847 let adapters: Vec<lora::LoraAdapter> = loras
848 .iter()
849 .map(|w| {
850 progress.info("Loading LoRA adapter (bypass)");
851 let adapter = lora::LoraAdapter::load(Path::new(&w.path))?;
852 progress.info(&format!(
853 "LoRA: {} layers, rank {}, scale {:.2}",
854 adapter.layers.len(),
855 adapter.rank,
856 w.scale,
857 ));
858 anyhow::Ok(adapter)
859 })
860 .collect::<Result<_>>()?;
861
862 let specs: Vec<lora::LoraSpec<'_>> = adapters
863 .iter()
864 .zip(loras.iter())
865 .map(|(adapter, w)| lora::LoraSpec {
866 adapter,
867 scale: w.scale,
868 path_hash: lora_path_hash(&w.path),
869 })
870 .collect();
871
872 let h = cfg.hidden_size;
876 let mlp_sz = (h as f64 * cfg.mlp_ratio) as usize;
877 let mut linear_out_dims: std::collections::HashMap<String, usize> =
878 std::collections::HashMap::new();
879 for idx in 0..cfg.depth {
880 linear_out_dims.insert(format!("double_blocks.{idx}.img_attn.qkv.weight"), 3 * h);
882 linear_out_dims.insert(format!("double_blocks.{idx}.txt_attn.qkv.weight"), 3 * h);
883 }
884 for idx in 0..cfg.depth_single_blocks {
885 linear_out_dims.insert(
887 format!("single_blocks.{idx}.linear1.weight"),
888 3 * h + mlp_sz,
889 );
890 }
891
892 let registry = lora_bypass::build_registry(&specs, &linear_out_dims, device, dtype)?;
893 progress.info(&format!(
894 "LoRA bypass: {} target tensors, adapters resident on {device:?}",
895 registry.len()
896 ));
897 Ok(Some(registry))
898}
899
900pub(crate) fn effective_loras(req: &mold_core::GenerateRequest) -> Vec<mold_core::LoraWeight> {
913 const ZERO_SCALE_EPS: f64 = 1e-8;
918
919 let raw: Vec<mold_core::LoraWeight> = if let Some(plural) = &req.loras {
920 if !plural.is_empty() {
921 plural.clone()
922 } else {
923 req.lora.iter().cloned().collect()
924 }
925 } else {
926 req.lora.iter().cloned().collect()
927 };
928
929 raw.into_iter()
930 .filter(|w| {
931 let keep = w.scale.abs() > ZERO_SCALE_EPS;
932 if !keep {
933 tracing::debug!(
934 path = w.path.as_str(),
935 scale = w.scale,
936 "dropping zero-scale LoRA from effective stack"
937 );
938 }
939 keep
940 })
941 .collect()
942}
943
/// Components held by a FLUX engine once it has loaded.
///
/// NOTE: field declaration order is also drop order; keep that in mind before
/// reordering, since these hold device-resident weights.
struct LoadedFlux {
    /// FLUX transformer. Held in an `Option` — presumably so it can be dropped or
    /// reloaded independently of the encoders; confirm against the load/offload code.
    flux_model: Option<FluxTransformer>,
    /// T5 text encoder.
    t5: encoders::t5::T5Encoder,
    /// CLIP text encoder.
    clip: encoders::clip::ClipEncoder,
    /// FLUX autoencoder (VAE).
    vae: flux::autoencoder::AutoEncoder,
    /// Device the components were loaded for (assumed to be the transformer device — TODO confirm).
    device: Device,
    /// Working dtype (presumably the value chosen by `flux_runtime_dtype` — confirm).
    dtype: DType,
    /// Dtype used for the VAE.
    vae_dtype: DType,
    /// Whether the schnell configuration is in use.
    is_schnell: bool,
    /// Whether the transformer was loaded from a GGUF (quantized) file.
    is_quantized: bool,
    /// Transformer weights file that was loaded; may be a generated cache rather than
    /// the original checkpoint (TODO confirm at construction site).
    transformer_path: PathBuf,
    /// T5 encoder weights file.
    t5_encoder_path: std::path::PathBuf,
}
972
/// Identity of one entry in a LoRA stack.
///
/// Two fingerprints are equal only when both the path hash and the exact bit pattern
/// of the scale match.
#[derive(Clone, PartialEq, Eq)]
struct LoraFingerprint {
    /// `lora_path_hash` of the adapter path.
    path_hash: u64,
    /// `scale.to_bits()`; comparing bits sidesteps float equality (`0.0` and `-0.0`
    /// therefore differ).
    scale_bits: u64,
}
981
982impl LoraFingerprint {
983 fn from_lora_weight(lora: &mold_core::LoraWeight) -> Self {
984 Self {
985 path_hash: lora_path_hash(&lora.path),
986 scale_bits: lora.scale.to_bits(),
987 }
988 }
989}
990
991fn fingerprint_stack(loras: &[mold_core::LoraWeight]) -> Vec<LoraFingerprint> {
997 loras
998 .iter()
999 .map(LoraFingerprint::from_lora_weight)
1000 .collect()
1001}
1002
/// FLUX text-to-image engine: configuration plus lazily loaded model components.
///
/// NOTE: field declaration order is also drop order.
pub struct FluxEngine {
    /// Shared engine state (model name, paths, progress reporter, loaded components).
    base: EngineBase<LoadedFlux>,
    /// Forces schnell (`Some(true)`) or dev (`Some(false)`) instead of detecting it
    /// from the model name / file name.
    is_schnell_override: Option<bool>,
    /// Optional T5 variant selector (semantics defined by the loader — TODO confirm).
    t5_variant: Option<String>,
    /// LRU cache of prompt conditioning (T5 + CLIP embedding pair) keyed by prompt text.
    prompt_cache: Mutex<LruCache<String, CachedTensorPair>>,
    /// Memoized result of `check_transformer_is_fp8`.
    transformer_is_fp8: Option<bool>,
    /// Presumably the resolved transformer file (e.g. a generated GGUF cache) — confirm.
    cached_transformer_path: Option<PathBuf>,
    /// Offload flag passed to `new` (exact offload policy not visible here).
    offload: bool,
    /// Fingerprints of the LoRA stack currently applied; reset by `load`.
    active_lora: Vec<LoraFingerprint>,
    /// LoRA delta cache shared with the LoRA builders (see `lora_delta_cache_handle`).
    lora_delta_cache: Arc<Mutex<super::lora::LoraDeltaCache>>,
    /// Optional pool shared across engines, used for tokenizers and CPU tensor caches.
    shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
    /// Device placement consulted by `load` when choosing the transformer device.
    pending_placement: Option<mold_core::types::DevicePlacement>,
}
1030
1031impl FluxEngine {
1032 #[allow(clippy::too_many_arguments)]
1037 pub fn new(
1038 model_name: String,
1039 paths: ModelPaths,
1040 is_schnell_override: Option<bool>,
1041 t5_variant: Option<String>,
1042 load_strategy: LoadStrategy,
1043 gpu_ordinal: usize,
1044 offload: bool,
1045 shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
1046 ) -> Self {
1047 Self {
1048 base: EngineBase::new(model_name, paths, load_strategy, gpu_ordinal),
1049 is_schnell_override,
1050 t5_variant,
1051 prompt_cache: Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY)),
1052 transformer_is_fp8: None,
1053 cached_transformer_path: None,
1054 offload,
1055 active_lora: Vec::new(),
1056 lora_delta_cache: Arc::new(Mutex::new(super::lora::LoraDeltaCache::new())),
1057 shared_pool,
1058 pending_placement: None,
1059 }
1060 }
1061
1062 fn lora_delta_cache_handle(&self) -> Option<Arc<Mutex<super::lora::LoraDeltaCache>>> {
1069 if std::env::var("MOLD_FLUX_DELTA_CACHE")
1070 .map(|v| v == "0")
1071 .unwrap_or(false)
1072 {
1073 None
1074 } else {
1075 Some(self.lora_delta_cache.clone())
1076 }
1077 }
1078
1079 fn get_cached_tokenizer(&self, path: &std::path::Path) -> Option<Arc<tokenizers::Tokenizer>> {
1081 let pool = self.shared_pool.as_ref()?;
1082 let pool = pool.lock().unwrap();
1083 pool.get_tokenizer(&path.to_string_lossy())
1084 }
1085
1086 fn cache_tokenizer(&self, path: &std::path::Path, tokenizer: Arc<tokenizers::Tokenizer>) {
1088 if let Some(ref pool) = self.shared_pool {
1089 let mut pool = pool.lock().unwrap();
1090 pool.insert_tokenizer(path.to_string_lossy().into_owned(), tokenizer);
1091 }
1092 }
1093
1094 fn load_vae_var_builder<'a>(
1096 &self,
1097 dtype: DType,
1098 device: &Device,
1099 component: &str,
1100 ) -> Result<VarBuilder<'a>> {
1101 if let Some(pool) = &self.shared_pool {
1102 let cached = pool
1103 .lock()
1104 .unwrap()
1105 .load_cpu_tensors(std::slice::from_ref(&self.base.paths.vae))?;
1106 let vb = crate::encoders::park::varbuilder_from_parked(cached.as_ref(), dtype, device);
1107 return Ok(flux_vae_var_builder(vb));
1108 }
1109
1110 let vb = crate::weight_loader::load_safetensors_with_progress(
1111 std::slice::from_ref(&self.base.paths.vae),
1112 dtype,
1113 device,
1114 component,
1115 &self.base.progress,
1116 )?;
1117 Ok(flux_vae_var_builder(vb))
1118 }
1119
1120 fn get_cached_safetensors(&self, path: &Path) -> Result<Option<Arc<HashMap<String, Tensor>>>> {
1121 let Some(pool) = &self.shared_pool else {
1122 return Ok(None);
1123 };
1124 let paths = [path];
1125 pool.lock().unwrap().load_safetensors_cpu_tensors(&paths)
1126 }
1127
1128 fn restore_prompt_cache(
1129 progress: &ProgressReporter,
1130 prompt_cache: &Mutex<LruCache<String, CachedTensorPair>>,
1131 prompt: &str,
1132 device: &Device,
1133 dtype: DType,
1134 ) -> Result<Option<(candle_core::Tensor, candle_core::Tensor)>> {
1135 let restored =
1136 restore_cached_tensor_pair(prompt_cache, &prompt_text_key(prompt), device, dtype)?;
1137 let Some(restored) = restored else {
1138 return Ok(None);
1139 };
1140 progress.cache_hit("prompt conditioning");
1141 Ok(Some(restored))
1142 }
1143
1144 fn store_prompt_cache(
1145 prompt_cache: &Mutex<LruCache<String, CachedTensorPair>>,
1146 prompt: &str,
1147 t5_emb: &candle_core::Tensor,
1148 clip_emb: &candle_core::Tensor,
1149 ) -> Result<()> {
1150 store_cached_tensor_pair(prompt_cache, prompt_text_key(prompt), t5_emb, clip_emb)
1151 }
1152}
1153
1154pub(crate) fn park_cond_to_cpu(tensor: &candle_core::Tensor) -> Result<candle_core::Tensor> {
1163 if tensor.device().is_cpu() {
1164 return Ok(tensor.clone());
1165 }
1166 Ok(tensor.to_device(&Device::Cpu)?)
1167}
1168
1169impl FluxEngine {
1170 fn detect_is_schnell(&self) -> bool {
1172 self.is_schnell_override.unwrap_or_else(|| {
1173 self.base.model_name.contains("schnell")
1174 || self
1175 .base
1176 .paths
1177 .transformer
1178 .file_name()
1179 .and_then(|n| n.to_str())
1180 .map(|n| n.to_ascii_lowercase().contains("schnell"))
1181 .unwrap_or(false)
1182 })
1183 }
1184
1185 fn check_transformer_is_fp8(&mut self, is_quantized: bool) -> bool {
1189 if let Some(cached) = self.transformer_is_fp8 {
1190 return cached;
1191 }
1192 let result = !is_quantized
1193 && flux_safetensors_transformer_is_fp8(&self.base.paths.transformer).unwrap_or(false);
1194 self.transformer_is_fp8 = Some(result);
1195 result
1196 }
1197
1198 fn detect_is_quantized(&self) -> bool {
1199 self.base
1200 .paths
1201 .transformer
1202 .extension()
1203 .and_then(|e| e.to_str())
1204 .map(|e| e.eq_ignore_ascii_case("gguf"))
1205 .unwrap_or(false)
1206 }
1207
1208 fn validate_paths(
1210 &self,
1211 ) -> Result<(
1212 std::path::PathBuf,
1213 std::path::PathBuf,
1214 std::path::PathBuf,
1215 std::path::PathBuf,
1216 )> {
1217 let t5_encoder_path = self
1218 .base
1219 .paths
1220 .t5_encoder
1221 .as_ref()
1222 .ok_or_else(|| anyhow::anyhow!("T5 encoder path required for FLUX models"))?
1223 .clone();
1224 let t5_tokenizer_path = self
1225 .base
1226 .paths
1227 .t5_tokenizer
1228 .as_ref()
1229 .ok_or_else(|| anyhow::anyhow!("T5 tokenizer path required for FLUX models"))?
1230 .clone();
1231 let clip_encoder_path = self
1232 .base
1233 .paths
1234 .clip_encoder
1235 .as_ref()
1236 .ok_or_else(|| anyhow::anyhow!("CLIP encoder path required for FLUX models"))?
1237 .clone();
1238 let clip_tokenizer_path = self
1239 .base
1240 .paths
1241 .clip_tokenizer
1242 .as_ref()
1243 .ok_or_else(|| anyhow::anyhow!("CLIP tokenizer path required for FLUX models"))?
1244 .clone();
1245
1246 for (label, path) in [
1247 ("transformer", &self.base.paths.transformer),
1248 ("vae", &self.base.paths.vae),
1249 ("t5_encoder", &t5_encoder_path),
1250 ("clip_encoder", &clip_encoder_path),
1251 ("t5_tokenizer", &t5_tokenizer_path),
1252 ("clip_tokenizer", &clip_tokenizer_path),
1253 ] {
1254 if !path.exists() {
1255 bail!("{label} file not found: {}", path.display());
1256 }
1257 }
1258
1259 Ok((
1260 t5_encoder_path,
1261 t5_tokenizer_path,
1262 clip_encoder_path,
1263 clip_tokenizer_path,
1264 ))
1265 }
1266
1267 pub fn load(&mut self) -> Result<()> {
1273 self.active_lora = Vec::new();
1274 if self.base.loaded.is_some() {
1275 return Ok(());
1276 }
1277
1278 if self.defers_eager_load() {
1283 return Ok(());
1284 }
1285
1286 let is_schnell = self.detect_is_schnell();
1287 tracing::info!(model = %self.base.model_name, "loading FLUX model components...");
1288
1289 let (t5_encoder_path, t5_tokenizer_path, clip_encoder_path, clip_tokenizer_path) =
1290 self.validate_paths()?;
1291
1292 let cpu = Device::Cpu;
1293 let transformer_ref = effective_device_ref(
1294 self.pending_placement.as_ref(),
1295 |adv| Some(adv.transformer),
1296 false,
1297 );
1298 let device = crate::device::resolve_device(Some(transformer_ref), || {
1299 crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)
1300 })?;
1301 let mut is_quantized = self.detect_is_quantized();
1302 let transformer_is_fp8 = self.check_transformer_is_fp8(is_quantized);
1303
1304 let transformer_path = if transformer_is_fp8 {
1308 let p = ensure_fp8_gguf_cache(&self.base.paths.transformer, &self.base.progress)?;
1309 is_quantized = true;
1310 p
1311 } else {
1312 self.base.paths.transformer.clone()
1313 };
1314
1315 let transformer_path = if is_quantized {
1317 ensure_gguf_embeddings(&transformer_path, is_schnell, &self.base.progress, None)?
1318 } else {
1319 transformer_path
1320 };
1321
1322 let gpu_dtype = flux_runtime_dtype(device.is_cuda(), is_quantized, false);
1323
1324 tracing::info!("GPU device: {:?}, GPU dtype: {:?}", device, gpu_dtype);
1325
1326 if !is_quantized {
1331 let xformer_size = std::fs::metadata(&transformer_path)
1332 .map(|m| m.len())
1333 .unwrap_or(0);
1334 let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
1337 if free > 0 && xformer_size > free {
1338 bail!(
1339 "transformer ({:.1} GB) exceeds available VRAM ({:.1} GB) — \
1340 use a quantized model (q8/q4) instead of full-precision for this GPU",
1341 xformer_size as f64 / 1e9,
1342 free as f64 / 1e9,
1343 );
1344 }
1345 }
1346
1347 let flux_cfg = if is_schnell {
1348 flux::model::Config::schnell()
1349 } else {
1350 flux::model::Config::dev()
1351 };
1352
1353 let xformer_label = if is_quantized {
1354 "Loading FLUX transformer (GPU, quantized)"
1355 } else {
1356 "Loading FLUX transformer (GPU, BF16)"
1357 };
1358 self.base.progress.stage_start(xformer_label);
1359 let xformer_stage = Instant::now();
1360 tracing::info!(
1361 path = %transformer_path.display(),
1362 quantized = is_quantized,
1363 "loading FLUX transformer on GPU..."
1364 );
1365
1366 let flux_model = if is_quantized {
1367 let vb = quantized_var_builder::VarBuilder::from_gguf(&transformer_path, &device)?;
1368 FluxTransformer::Quantized(flux::quantized_model::Flux::new(&flux_cfg, vb)?)
1369 } else {
1370 let flux_vb = flux_transformer_var_builder(flux_safetensors_var_builder(
1371 &transformer_path,
1372 gpu_dtype,
1373 &device,
1374 "FLUX transformer",
1375 &self.base.progress,
1376 )?);
1377 FluxTransformer::BF16(flux::model::Flux::new(&flux_cfg, flux_vb)?)
1378 };
1379 self.base
1380 .progress
1381 .stage_done(xformer_label, xformer_stage.elapsed());
1382 tracing::info!("FLUX transformer loaded on GPU");
1383
1384 let vae_ref =
1387 effective_device_ref(self.pending_placement.as_ref(), |adv| Some(adv.vae), false);
1388 let vae_device = crate::device::resolve_device(Some(vae_ref), || Ok(device.clone()))?;
1389 self.base.progress.stage_start("Loading VAE (GPU)");
1390 let vae_stage = Instant::now();
1391 tracing::info!(path = %self.base.paths.vae.display(), "loading VAE on GPU...");
1392 let vae_dtype = crate::device::resolve_vae_dtype(gpu_dtype);
1394 let vae_vb = self.load_vae_var_builder(vae_dtype, &vae_device, "VAE")?;
1395 let vae_cfg = if is_schnell {
1396 flux::autoencoder::Config::schnell()
1397 } else {
1398 flux::autoencoder::Config::dev()
1399 };
1400 let vae = flux::autoencoder::AutoEncoder::new(&vae_cfg, vae_vb)?;
1401 self.base
1402 .progress
1403 .stage_done("Loading VAE (GPU)", vae_stage.elapsed());
1404 tracing::info!("VAE loaded on GPU");
1405
1406 let free_raw = free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
1412 let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
1413 if free_raw > 0 {
1414 self.base.progress.info(&format!(
1415 "Free VRAM after transformer+VAE: {}",
1416 fmt_gb(free_raw)
1417 ));
1418 tracing::info!(
1419 free_vram = free_raw,
1420 free_vram_usable = free,
1421 "free VRAM after loading transformer + VAE"
1422 );
1423 }
1424
1425 self.base.progress.stage_start("Selecting T5 encoder");
1427 let t5_resolve_start = Instant::now();
1428 let t5_preference = self.t5_variant.as_deref();
1429 let (resolved_t5_path, t5_on_gpu, _t5_auto_device_label) =
1430 crate::encoders::variant_resolution::resolve_t5_variant(
1431 &self.base.progress,
1432 t5_preference,
1433 &device,
1434 free,
1435 &t5_encoder_path,
1436 )?;
1437 self.base
1438 .progress
1439 .stage_done("Selecting T5 encoder", t5_resolve_start.elapsed());
1440 let t5_ref = effective_device_ref(self.pending_placement.as_ref(), |adv| adv.t5, true);
1442 let auto_t5_device = if t5_on_gpu {
1443 device.clone()
1444 } else {
1445 cpu.clone()
1446 };
1447 let t5_device_owned =
1448 crate::device::resolve_device(Some(t5_ref), || Ok(auto_t5_device.clone()))?;
1449 let t5_device = &t5_device_owned;
1450 let t5_on_gpu = !t5_device.is_cpu();
1451 let t5_device_label = if t5_on_gpu { "GPU" } else { "CPU" };
1452 let t5_dtype = if t5_on_gpu { gpu_dtype } else { DType::F32 };
1453
1454 let t5_stage_label = format!("Loading T5 encoder ({t5_device_label})");
1456 self.base.progress.stage_start(&t5_stage_label);
1457 let t5_stage = Instant::now();
1458 tracing::info!(
1459 path = %resolved_t5_path.display(),
1460 device = %t5_device_label,
1461 "loading T5 encoder..."
1462 );
1463 let cached_t5_tok = self.get_cached_tokenizer(&t5_tokenizer_path);
1464 let cached_t5_tensors = self.get_cached_safetensors(&resolved_t5_path)?;
1465 let t5 = encoders::t5::T5Encoder::load_with_tokenizer_and_tensors(
1466 &resolved_t5_path,
1467 &t5_tokenizer_path,
1468 t5_device,
1469 t5_dtype,
1470 &self.base.progress,
1471 cached_t5_tok,
1472 cached_t5_tensors,
1473 )?;
1474 self.cache_tokenizer(&t5_tokenizer_path, t5.tokenizer_arc());
1475 self.base
1476 .progress
1477 .stage_done(&t5_stage_label, t5_stage.elapsed());
1478 tracing::info!(device = %t5_device_label, "T5 encoder loaded");
1479
1480 let free_after_t5 = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
1483 let clip_on_gpu = should_use_gpu(
1484 device.is_cuda(),
1485 device.is_metal(),
1486 free_after_t5,
1487 CLIP_VRAM_THRESHOLD,
1488 );
1489 let clip_ref =
1490 effective_device_ref(self.pending_placement.as_ref(), |adv| adv.clip_l, true);
1491 let auto_clip_device = if clip_on_gpu {
1492 device.clone()
1493 } else {
1494 cpu.clone()
1495 };
1496 let clip_device_owned =
1497 crate::device::resolve_device(Some(clip_ref), || Ok(auto_clip_device.clone()))?;
1498 let clip_device = &clip_device_owned;
1499 let clip_on_gpu = !clip_device.is_cpu();
1500 let clip_dtype = if clip_on_gpu { gpu_dtype } else { DType::F32 };
1501 let clip_device_label = if clip_on_gpu { "GPU" } else { "CPU" };
1502
1503 let clip_stage_label = format!("Loading CLIP encoder ({clip_device_label})");
1505 self.base.progress.stage_start(&clip_stage_label);
1506 let clip_stage = Instant::now();
1507 tracing::info!(
1508 path = %clip_encoder_path.display(),
1509 device = clip_device_label,
1510 "loading CLIP encoder..."
1511 );
1512 let cached_clip_tok = self.get_cached_tokenizer(&clip_tokenizer_path);
1513 let cached_clip_tensors = self.get_cached_safetensors(&clip_encoder_path)?;
1514 let clip = encoders::clip::ClipEncoder::load_with_tokenizer_and_tensors(
1515 &clip_encoder_path,
1516 &clip_tokenizer_path,
1517 clip_device,
1518 clip_dtype,
1519 &self.base.progress,
1520 cached_clip_tok,
1521 cached_clip_tensors,
1522 )?;
1523 self.cache_tokenizer(&clip_tokenizer_path, clip.tokenizer_arc());
1524 self.base
1525 .progress
1526 .stage_done(&clip_stage_label, clip_stage.elapsed());
1527 tracing::info!(device = clip_device_label, "CLIP encoder loaded");
1528
1529 self.base.loaded = Some(LoadedFlux {
1530 flux_model: Some(flux_model),
1531 t5,
1532 clip,
1533 vae,
1534 device,
1535 dtype: gpu_dtype,
1536 vae_dtype,
1537 is_schnell,
1538 is_quantized,
1539 transformer_path,
1540 t5_encoder_path: resolved_t5_path,
1541 });
1542
1543 tracing::info!(model = %self.base.model_name, "all model components loaded successfully");
1544 Ok(())
1545 }
1546
1547 fn generate_sequential(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1557 let is_schnell = self.detect_is_schnell();
1558 let mut is_quantized = self.detect_is_quantized();
1559
1560 let (t5_encoder_path, t5_tokenizer_path, clip_encoder_path, clip_tokenizer_path) =
1561 self.validate_paths()?;
1562
1563 if let Some(warning) = check_memory_budget(&self.base.paths, LoadStrategy::Sequential) {
1565 self.base.progress.info(&warning);
1566 }
1567
1568 let transformer_ref = effective_device_ref(
1569 self.pending_placement.as_ref(),
1570 |adv| Some(adv.transformer),
1571 false,
1572 );
1573 let device = crate::device::resolve_device(Some(transformer_ref), || {
1574 crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)
1575 })?;
1576
1577 let transformer_path = if let Some(ref cached) = self.cached_transformer_path {
1579 if cached
1580 .extension()
1581 .and_then(|e| e.to_str())
1582 .map(|e| e.eq_ignore_ascii_case("gguf"))
1583 .unwrap_or(false)
1584 {
1585 is_quantized = true;
1586 }
1587 cached.clone()
1588 } else {
1589 let transformer_is_fp8 = self.check_transformer_is_fp8(is_quantized);
1590 let p = if transformer_is_fp8 {
1591 let p = ensure_fp8_gguf_cache(&self.base.paths.transformer, &self.base.progress)?;
1592 is_quantized = true;
1593 p
1594 } else {
1595 self.base.paths.transformer.clone()
1596 };
1597 let p = if is_quantized {
1599 ensure_gguf_embeddings(&p, is_schnell, &self.base.progress, None)?
1600 } else {
1601 p
1602 };
1603 self.cached_transformer_path = Some(p.clone());
1604 p
1605 };
1606
1607 let gpu_dtype = flux_runtime_dtype(device.is_cuda(), is_quantized, false);
1608
1609 let start = Instant::now();
1610 let seed = req.seed.unwrap_or_else(rand_seed);
1611
1612 let width = req.width as usize;
1613 let height = req.height as usize;
1614
1615 tracing::info!(
1616 prompt = %req.prompt,
1617 seed, width, height,
1618 steps = req.steps,
1619 "starting sequential FLUX generation"
1620 );
1621
1622 self.base
1623 .progress
1624 .info("Using sequential loading (load-use-drop) to minimize peak memory");
1625
1626 let (t5_emb, clip_emb) = if let Some((t5_emb, clip_emb)) = Self::restore_prompt_cache(
1627 &self.base.progress,
1628 &self.prompt_cache,
1629 &req.prompt,
1630 &device,
1631 gpu_dtype,
1632 )? {
1633 (t5_emb, clip_emb)
1634 } else {
1635 let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
1638 self.base.progress.stage_start("Selecting T5 encoder");
1639 let t5_resolve_start = Instant::now();
1640 let t5_preference = self.t5_variant.as_deref();
1641 let (resolved_t5_path, t5_on_gpu, _t5_auto_device_label) =
1642 crate::encoders::variant_resolution::resolve_t5_variant(
1643 &self.base.progress,
1644 t5_preference,
1645 &device,
1646 free,
1647 &t5_encoder_path,
1648 )?;
1649 self.base
1650 .progress
1651 .stage_done("Selecting T5 encoder", t5_resolve_start.elapsed());
1652
1653 let t5_ref = effective_device_ref(self.pending_placement.as_ref(), |adv| adv.t5, true);
1654 let auto_t5_device = if t5_on_gpu {
1655 device.clone()
1656 } else {
1657 Device::Cpu
1658 };
1659 let t5_device_owned =
1660 crate::device::resolve_device(Some(t5_ref), || Ok(auto_t5_device.clone()))?;
1661 let t5_device = &t5_device_owned;
1662 let t5_on_gpu = !t5_device.is_cpu();
1663 let t5_device_label = if t5_on_gpu { "GPU" } else { "CPU" };
1664 let t5_dtype = if t5_on_gpu { gpu_dtype } else { DType::F32 };
1665
1666 let t5_size = std::fs::metadata(&resolved_t5_path)
1667 .map(|m| m.len())
1668 .unwrap_or(0);
1669 let t5_activation_budget = crate::device::activation_bytes(
1672 req.width,
1673 req.height,
1674 1,
1675 crate::device::dtype_bytes(t5_dtype),
1676 crate::device::ActivationFamily::SmallTransformer,
1677 );
1678 preflight_memory_check("T5 encoder", t5_size, t5_activation_budget)?;
1679 if let Some(status) = memory_status_string() {
1680 self.base.progress.info(&status);
1681 }
1682
1683 let t5_stage_label = format!("Loading T5 encoder ({t5_device_label})");
1684 self.base.progress.stage_start(&t5_stage_label);
1685 let t5_stage = Instant::now();
1686 let cached_t5_tok = self.get_cached_tokenizer(&t5_tokenizer_path);
1687 let cached_t5_tensors = self.get_cached_safetensors(&resolved_t5_path)?;
1688 let mut t5 = encoders::t5::T5Encoder::load_with_tokenizer_and_tensors(
1689 &resolved_t5_path,
1690 &t5_tokenizer_path,
1691 t5_device,
1692 t5_dtype,
1693 &self.base.progress,
1694 cached_t5_tok,
1695 cached_t5_tensors,
1696 )?;
1697 self.cache_tokenizer(&t5_tokenizer_path, t5.tokenizer_arc());
1698 self.base
1699 .progress
1700 .stage_done(&t5_stage_label, t5_stage.elapsed());
1701
1702 self.base.progress.stage_start("Encoding prompt (T5)");
1703 let encode_t5 = Instant::now();
1704 let t5_emb = park_cond_to_cpu(&t5.encode(&req.prompt, &device, gpu_dtype)?)?;
1709 self.base
1710 .progress
1711 .stage_done("Encoding prompt (T5)", encode_t5.elapsed());
1712
1713 drop(t5);
1714 self.base.progress.info("Freed T5 encoder");
1715 tracing::info!("T5 encoder dropped (sequential mode)");
1716
1717 let free_for_clip = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
1721 let clip_on_gpu = should_use_gpu(
1722 device.is_cuda(),
1723 device.is_metal(),
1724 free_for_clip,
1725 CLIP_VRAM_THRESHOLD,
1726 );
1727 let clip_ref =
1728 effective_device_ref(self.pending_placement.as_ref(), |adv| adv.clip_l, true);
1729 let auto_clip_device = if clip_on_gpu {
1730 device.clone()
1731 } else {
1732 Device::Cpu
1733 };
1734 let clip_device_owned =
1735 crate::device::resolve_device(Some(clip_ref), || Ok(auto_clip_device.clone()))?;
1736 let clip_device = &clip_device_owned;
1737 let clip_on_gpu = !clip_device.is_cpu();
1738 let clip_dtype = if clip_on_gpu { gpu_dtype } else { DType::F32 };
1739 let clip_device_label = if clip_on_gpu { "GPU" } else { "CPU" };
1740
1741 let clip_stage_label = format!("Loading CLIP encoder ({clip_device_label})");
1742 self.base.progress.stage_start(&clip_stage_label);
1743 let clip_stage = Instant::now();
1744 let cached_clip_tok = self.get_cached_tokenizer(&clip_tokenizer_path);
1745 let cached_clip_tensors = self.get_cached_safetensors(&clip_encoder_path)?;
1746 let clip = encoders::clip::ClipEncoder::load_with_tokenizer_and_tensors(
1747 &clip_encoder_path,
1748 &clip_tokenizer_path,
1749 clip_device,
1750 clip_dtype,
1751 &self.base.progress,
1752 cached_clip_tok,
1753 cached_clip_tensors,
1754 )?;
1755 self.cache_tokenizer(&clip_tokenizer_path, clip.tokenizer_arc());
1756 self.base
1757 .progress
1758 .stage_done(&clip_stage_label, clip_stage.elapsed());
1759
1760 self.base.progress.stage_start("Encoding prompt (CLIP)");
1761 let encode_clip = Instant::now();
1762 let clip_emb = {
1766 let mut clip = clip;
1767 park_cond_to_cpu(&clip.encode(&req.prompt, &device, gpu_dtype)?)?
1768 };
1769 self.base
1770 .progress
1771 .stage_done("Encoding prompt (CLIP)", encode_clip.elapsed());
1772
1773 self.base.progress.info("Freed CLIP encoder");
1774 tracing::info!("CLIP encoder dropped (sequential mode)");
1775
1776 Self::store_prompt_cache(&self.prompt_cache, &req.prompt, &t5_emb, &clip_emb)?;
1780 (t5_emb, clip_emb)
1781 };
1782
1783 device.synchronize()?;
1786
1787 let xformer_size = std::fs::metadata(&transformer_path)
1789 .map(|m| m.len())
1790 .unwrap_or(0);
1791 let vae_file_size = std::fs::metadata(&self.base.paths.vae)
1792 .map(|m| m.len())
1793 .unwrap_or(0);
1794
1795 let activation_budget = crate::device::activation_bytes(
1805 req.width,
1806 req.height,
1807 1, crate::device::dtype_bytes(gpu_dtype),
1809 crate::device::ActivationFamily::FluxDit,
1810 );
1811
1812 let use_offload = if !is_quantized {
1814 let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
1818 if self.offload || should_offload(xformer_size, free, activation_budget) {
1819 if free > 0 && free < MIN_OFFLOAD_VRAM {
1820 bail!(
1821 "GPU only has {:.1} GB free — at least {:.1} GB is required \
1822 for block-level offloading",
1823 free as f64 / 1e9,
1824 MIN_OFFLOAD_VRAM as f64 / 1e9,
1825 );
1826 }
1827 true
1828 } else if free > 0 && xformer_size > free {
1829 bail!(
1830 "transformer ({:.1} GB) exceeds available VRAM ({:.1} GB) — \
1831 use a quantized model (q8/q4) or --offload for block-level streaming",
1832 xformer_size as f64 / 1e9,
1833 free as f64 / 1e9,
1834 );
1835 } else {
1836 false
1837 }
1838 } else {
1839 if self.offload {
1840 tracing::warn!(
1841 "block-level offloading is not supported for quantized models; \
1842 --offload / MOLD_OFFLOAD=1 will be ignored"
1843 );
1844 }
1845 false
1846 };
1847
1848 if !use_offload || device.is_metal() {
1851 preflight_memory_check(
1852 "FLUX transformer + VAE",
1853 xformer_size + vae_file_size,
1854 activation_budget,
1855 )?;
1856 }
1857 if let Some(status) = memory_status_string() {
1858 self.base.progress.info(&status);
1859 }
1860
1861 let flux_cfg = if is_schnell {
1862 flux::model::Config::schnell()
1863 } else {
1864 flux::model::Config::dev()
1865 };
1866
1867 let active_loras = effective_loras(req);
1868 let has_lora = !active_loras.is_empty();
1869 let xformer_label = if has_lora && use_offload {
1870 "Loading FLUX transformer + LoRA (offloaded)"
1871 } else if has_lora && is_quantized {
1872 "Loading FLUX transformer + LoRA (GPU, quantized + selective deq)"
1873 } else if has_lora {
1874 "Loading FLUX transformer + LoRA (GPU, BF16)"
1875 } else if use_offload {
1876 "Loading FLUX transformer (offloaded, blocks on CPU)"
1877 } else if is_quantized {
1878 "Loading FLUX transformer (GPU, quantized)"
1879 } else {
1880 "Loading FLUX transformer (GPU, BF16)"
1881 };
1882 self.base.progress.stage_start(xformer_label);
1883 let xformer_stage = Instant::now();
1884
1885 let bypass_mode = LoraBypassMode::from_env();
1886 let use_offload_bypass = use_offload && has_lora && bypass_mode != LoraBypassMode::Off;
1892
1893 let flux_model = if use_offload {
1894 let cpu_vb: VarBuilder = if has_lora && !use_offload_bypass {
1899 flux_lora_var_builder(
1901 &transformer_path,
1902 &active_loras,
1903 gpu_dtype,
1904 &Device::Cpu,
1905 &self.base.progress,
1906 self.lora_delta_cache_handle(),
1907 )?
1908 } else {
1909 flux_transformer_var_builder(flux_safetensors_var_builder(
1910 &transformer_path,
1911 gpu_dtype,
1912 &Device::Cpu,
1913 "FLUX transformer",
1914 &self.base.progress,
1915 )?)
1916 };
1917 let mut offloaded = crate::flux::offload::OffloadedFluxTransformer::load(
1918 cpu_vb,
1919 &flux_cfg,
1920 &device,
1921 &self.base.progress,
1922 )?;
1923 if use_offload_bypass {
1924 let registry = build_lora_registry(
1925 &active_loras,
1926 &flux_cfg,
1927 &device,
1928 gpu_dtype,
1929 &self.base.progress,
1930 )?;
1931 offloaded.set_lora_registry(registry);
1932 }
1933 FluxTransformer::Offloaded(offloaded)
1934 } else if is_quantized && has_lora {
1935 let bypass_quantized = bypass_mode != LoraBypassMode::Off;
1940 if bypass_quantized {
1941 let registry = build_lora_registry(
1942 &active_loras,
1943 &flux_cfg,
1944 &device,
1945 gpu_dtype,
1946 &self.base.progress,
1947 )?;
1948 let vb = quantized_var_builder::VarBuilder::from_gguf(&transformer_path, &device)?;
1949 FluxTransformer::QuantizedBypass(
1950 crate::flux::quantized_transformer::QuantizedFluxTransformer::load(
1951 &flux_cfg,
1952 vb,
1953 registry.as_ref(),
1954 &self.base.progress,
1955 )?,
1956 )
1957 } else {
1958 let vb = flux_gguf_lora_var_builder(
1960 &transformer_path,
1961 &active_loras,
1962 &device,
1963 &self.base.progress,
1964 self.lora_delta_cache_handle(),
1965 )?;
1966 FluxTransformer::Quantized(flux::quantized_model::Flux::new(&flux_cfg, vb)?)
1967 }
1968 } else if is_quantized {
1969 let vb = quantized_var_builder::VarBuilder::from_gguf(&transformer_path, &device)?;
1970 FluxTransformer::Quantized(flux::quantized_model::Flux::new(&flux_cfg, vb)?)
1971 } else if has_lora {
1972 let flux_vb = flux_lora_var_builder(
1974 &transformer_path,
1975 &active_loras,
1976 gpu_dtype,
1977 &device,
1978 &self.base.progress,
1979 self.lora_delta_cache_handle(),
1980 )?;
1981 FluxTransformer::BF16(flux::model::Flux::new(&flux_cfg, flux_vb)?)
1982 } else {
1983 let flux_vb = flux_transformer_var_builder(flux_safetensors_var_builder(
1984 &transformer_path,
1985 gpu_dtype,
1986 &device,
1987 "FLUX transformer",
1988 &self.base.progress,
1989 )?);
1990 FluxTransformer::BF16(flux::model::Flux::new(&flux_cfg, flux_vb)?)
1991 };
1992 self.base
1993 .progress
1994 .stage_done(xformer_label, xformer_stage.elapsed());
1995 if let Some(status) = memory_status_string() {
1996 self.base.progress.info(&status);
1997 }
1998
1999 let noise_dtype = if is_quantized { DType::F32 } else { gpu_dtype };
2001 let latent_h = height / 16 * 2;
2002 let latent_w = width / 16 * 2;
2003 let image_seq_len = (latent_h / 2) * (latent_w / 2);
2007 let mut timesteps = if is_schnell {
2008 flux::sampling::get_schedule(req.steps as usize, None)
2009 } else {
2010 flux::sampling::get_schedule(req.steps as usize, Some((image_seq_len, 0.5, 1.15)))
2011 };
2012
2013 if req.source_image.is_some() {
2014 let start_index = crate::img2img::img2img_start_index(req.steps as usize, req.strength);
2015 timesteps = timesteps[start_index..].to_vec();
2016 tracing::info!(
2017 strength = req.strength,
2018 start_index,
2019 start_timestep = timesteps[0],
2020 schedule = ?timesteps,
2021 remaining_steps = timesteps.len().saturating_sub(1),
2022 "img2img: truncated schedule from strength"
2023 );
2024 }
2025
2026 let vae_cfg = if is_schnell {
2030 flux::autoencoder::Config::schnell()
2031 } else {
2032 flux::autoencoder::Config::dev()
2033 };
2034 let early_vae_dtype = crate::device::resolve_vae_dtype(gpu_dtype);
2038
2039 let (img, inpaint_ctx, early_vae) = if let Some(ref source_bytes) = req.source_image {
2040 let start_t = timesteps[0];
2041
2042 self.base.progress.stage_start("Loading VAE (GPU)");
2044 let vae_stage = Instant::now();
2045 let vae_vb = self.load_vae_var_builder(early_vae_dtype, &device, "VAE")?;
2046 let vae = flux::autoencoder::AutoEncoder::new(&vae_cfg, vae_vb)?;
2047 self.base
2048 .progress
2049 .stage_done("Loading VAE (GPU)", vae_stage.elapsed());
2050
2051 self.base
2052 .progress
2053 .stage_start("Encoding source image (VAE)");
2054 let encode_start = Instant::now();
2055 let source_tensor = crate::img_utils::decode_source_image(
2056 source_bytes,
2057 req.width,
2058 req.height,
2059 crate::img_utils::NormalizeRange::MinusOneToOne,
2060 &device,
2061 early_vae_dtype,
2062 )?;
2063 let encoded = vae.encode(&source_tensor)?;
2065 self.base
2066 .progress
2067 .stage_done("Encoding source image (VAE)", encode_start.elapsed());
2068
2069 let noise = crate::engine::seeded_randn(
2072 seed,
2073 &[1, 16, latent_h, latent_w],
2074 &device,
2075 noise_dtype,
2076 )?;
2077 let encoded = encoded.to_dtype(noise_dtype)?;
2078
2079 let inpaint_ctx = if let Some(ref mask_bytes) = req.mask_image {
2081 let mask = crate::img_utils::decode_mask_image(
2082 mask_bytes,
2083 latent_h,
2084 latent_w,
2085 &device,
2086 noise_dtype,
2087 )?;
2088 Some(crate::img_utils::InpaintContext {
2089 original_latents: encoded.clone(),
2090 mask,
2091 noise: noise.clone(),
2092 })
2093 } else {
2094 None
2095 };
2096
2097 let img = ((&encoded * (1.0 - start_t))? + (&noise * start_t)?)?;
2100 (img, inpaint_ctx, Some(vae))
2101 } else {
2102 let img = crate::engine::seeded_randn(
2103 seed,
2104 &[1, 16, latent_h, latent_w],
2105 &device,
2106 noise_dtype,
2107 )?;
2108 (img, None, None)
2109 };
2110
2111 let t5_emb = t5_emb.to_device(&device)?;
2116 let clip_emb = clip_emb.to_device(&device)?;
2117 let (t5_emb_state, clip_emb_state, img_state) = if is_quantized {
2118 (
2119 t5_emb.to_dtype(DType::F32)?,
2120 clip_emb.to_dtype(DType::F32)?,
2121 img.to_dtype(DType::F32)?,
2122 )
2123 } else {
2124 (t5_emb, clip_emb, img)
2125 };
2126
2127 let state = flux::sampling::State::new(&t5_emb_state, &clip_emb_state, &img_state)?;
2128 let inpaint_ctx = inpaint_ctx
2129 .as_ref()
2130 .map(crate::img2img::pack_flux_inpaint_context)
2131 .transpose()?;
2132
2133 let denoise_label = format!("Denoising ({} steps)", timesteps.len().saturating_sub(1));
2134 self.base.progress.stage_start(&denoise_label);
2135 let denoise_start = Instant::now();
2136
2137 let img = flux_model.denoise(
2138 &state.img,
2139 &state.img_ids,
2140 &state.txt,
2141 &state.txt_ids,
2142 &state.vec,
2143 ×teps,
2144 req.guidance,
2145 &self.base.progress,
2146 inpaint_ctx.as_ref(),
2147 )?;
2148
2149 let img = flux::sampling::unpack(&img, height, width)?;
2150 self.base
2151 .progress
2152 .stage_done(&denoise_label, denoise_start.elapsed());
2153
2154 drop(inpaint_ctx);
2156 drop(flux_model);
2157 self.base.progress.info("Freed FLUX transformer");
2158 drop(state);
2159 drop(t5_emb_state);
2160 drop(clip_emb_state);
2161 drop(img_state);
2162 device.synchronize()?;
2164 tracing::info!("Transformer dropped (sequential mode), decoding VAE...");
2165
2166 let vae_dtype = crate::device::resolve_vae_dtype(gpu_dtype);
2172 let vae = if let Some(vae) = early_vae {
2173 vae
2174 } else {
2175 self.base.progress.stage_start("Loading VAE (GPU)");
2176 let vae_stage = Instant::now();
2177 let vae_vb = self.load_vae_var_builder(vae_dtype, &device, "VAE")?;
2178 let vae = flux::autoencoder::AutoEncoder::new(&vae_cfg, vae_vb)?;
2179 self.base
2180 .progress
2181 .stage_done("Loading VAE (GPU)", vae_stage.elapsed());
2182 vae
2183 };
2184 self.base.progress.stage_start("VAE decode");
2185 let vae_decode_start = Instant::now();
2186 let img_for_vae = img.to_dtype(vae_dtype)?;
2187 let device_for_sync = device.clone();
2188 let img = crate::vae_tiling::decode_with_oom_fallback(
2189 &img_for_vae,
2190 |latents| vae.decode(latents).map_err(Into::into),
2191 || {
2192 if let Err(e) = device_for_sync.synchronize() {
2193 tracing::warn!(
2194 "FLUX (sequential) device.synchronize() after VAE OOM failed: {e}"
2195 );
2196 }
2197 },
2198 )?;
2199
2200 let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(DType::U8)?;
2201 let img = img.i(0)?;
2202
2203 self.base
2204 .progress
2205 .stage_done("VAE decode", vae_decode_start.elapsed());
2206 let output_metadata = build_output_metadata(req, seed, None);
2209 let image_bytes = encode_image(
2210 &img,
2211 req.resolved_output_format(),
2212 req.width,
2213 req.height,
2214 output_metadata.as_ref(),
2215 )?;
2216
2217 let generation_time_ms = start.elapsed().as_millis() as u64;
2218 tracing::info!(generation_time_ms, seed, "sequential generation complete");
2219
2220 Ok(GenerateResponse {
2221 images: vec![ImageData {
2222 data: image_bytes,
2223 format: req.resolved_output_format(),
2224 width: req.width,
2225 height: req.height,
2226 index: 0,
2227 }],
2228 generation_time_ms,
2229 model: req.model.clone(),
2230 seed_used: seed,
2231 video: None,
2232 gpu: None,
2233 })
2234 }
2235}
2236
2237impl FluxEngine {
2238 fn defers_eager_load(&mut self) -> bool {
2239 self.base.load_strategy == LoadStrategy::Sequential
2240 || (self.offload && !self.detect_is_quantized())
2241 }
2242
2243 fn uses_sequential_generate_path(&mut self) -> bool {
2244 self.defers_eager_load()
2245 }
2246
2247 fn generate_inner(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
2248 if req.scheduler.is_some() {
2249 tracing::warn!("scheduler selection not supported for FLUX (flow-matching), ignoring");
2250 }
2251
2252 if self.uses_sequential_generate_path() {
2257 return self.generate_sequential(req);
2258 }
2259
2260 let progress = &self.base.progress;
2266 let prompt_cache = &self.prompt_cache;
2267
2268 let t5_encoder_path = self
2270 .base
2271 .loaded
2272 .as_ref()
2273 .map(|l| l.t5_encoder_path.clone())
2274 .or_else(|| self.base.paths.t5_encoder.clone())
2275 .ok_or_else(|| anyhow::anyhow!("T5 encoder path required for FLUX models"))?;
2276 let clip_encoder_path = self
2277 .base
2278 .paths
2279 .clip_encoder
2280 .clone()
2281 .ok_or_else(|| anyhow::anyhow!("CLIP encoder path required for FLUX models"))?;
2282 let transformer_path = self
2283 .base
2284 .loaded
2285 .as_ref()
2286 .map(|l| l.transformer_path.clone())
2287 .unwrap_or_else(|| self.base.paths.transformer.clone());
2288
2289 let cache_handle = self.lora_delta_cache_handle();
2293
2294 let mut loaded = OptionRestoreGuard::take(&mut self.base.loaded)
2295 .ok_or_else(|| anyhow::anyhow!("model not loaded — call load() first"))?;
2296
2297 let start = Instant::now();
2298 let seed = req.seed.unwrap_or_else(rand_seed);
2299
2300 let width = req.width as usize;
2301 let height = req.height as usize;
2302 let loaded_dtype = loaded.dtype;
2303 let loaded_device = loaded.device.clone();
2304
2305 tracing::info!(
2306 prompt = %req.prompt,
2307 seed,
2308 width,
2309 height,
2310 steps = req.steps,
2311 "starting generation"
2312 );
2313
2314 (|| -> Result<GenerateResponse> {
2315 let active_loras = effective_loras(req);
2318 let requested_stack = fingerprint_stack(&active_loras);
2319 if requested_stack != self.active_lora {
2320 if loaded.flux_model.is_some() {
2321 loaded.flux_model = None;
2322 loaded.device.synchronize()?;
2323 }
2324 self.active_lora = requested_stack;
2325 }
2326
2327 if loaded.flux_model.is_none() {
2328 let has_lora = !active_loras.is_empty();
2329 let xformer_label = match (loaded.is_quantized, has_lora) {
2330 (true, true) => "Reloading FLUX transformer (GPU, quantized + LoRA)",
2331 (true, false) => "Reloading FLUX transformer (GPU, quantized)",
2332 (false, true) if loaded.dtype == DType::F16 => {
2333 "Reloading FLUX transformer (GPU, FP16 + LoRA)"
2334 }
2335 (false, true) => "Reloading FLUX transformer (GPU, BF16 + LoRA)",
2336 (false, false) if loaded.dtype == DType::F16 => {
2337 "Reloading FLUX transformer (GPU, FP16)"
2338 }
2339 (false, false) => "Reloading FLUX transformer (GPU, BF16)",
2340 };
2341 progress.stage_start(xformer_label);
2342 let reload_start = Instant::now();
2343 let flux_cfg = if loaded.is_schnell {
2344 flux::model::Config::schnell()
2345 } else {
2346 flux::model::Config::dev()
2347 };
2348 let bypass_mode = LoraBypassMode::from_env();
2349 loaded.flux_model = Some(if loaded.is_quantized && has_lora {
2350 let bypass_quantized = bypass_mode != LoraBypassMode::Off;
2356 if bypass_quantized {
2357 let registry = build_lora_registry(
2358 &active_loras,
2359 &flux_cfg,
2360 &loaded.device,
2361 loaded.dtype,
2362 progress,
2363 )?;
2364 let vb = quantized_var_builder::VarBuilder::from_gguf(
2365 &transformer_path,
2366 &loaded.device,
2367 )?;
2368 FluxTransformer::QuantizedBypass(
2369 crate::flux::quantized_transformer::QuantizedFluxTransformer::load(
2370 &flux_cfg,
2371 vb,
2372 registry.as_ref(),
2373 progress,
2374 )?,
2375 )
2376 } else {
2377 let vb = flux_gguf_lora_var_builder(
2378 &transformer_path,
2379 &active_loras,
2380 &loaded.device,
2381 progress,
2382 cache_handle.clone(),
2383 )?;
2384 FluxTransformer::Quantized(flux::quantized_model::Flux::new(&flux_cfg, vb)?)
2385 }
2386 } else if loaded.is_quantized {
2387 let vb = quantized_var_builder::VarBuilder::from_gguf(
2388 &transformer_path,
2389 &loaded.device,
2390 )?;
2391 FluxTransformer::Quantized(flux::quantized_model::Flux::new(&flux_cfg, vb)?)
2392 } else if has_lora {
2393 let flux_vb = flux_lora_var_builder(
2395 &transformer_path,
2396 &active_loras,
2397 loaded.dtype,
2398 &loaded.device,
2399 progress,
2400 cache_handle.clone(),
2401 )?;
2402 FluxTransformer::BF16(flux::model::Flux::new(&flux_cfg, flux_vb)?)
2403 } else {
2404 let flux_vb = flux_transformer_var_builder(flux_safetensors_var_builder(
2405 &transformer_path,
2406 loaded.dtype,
2407 &loaded.device,
2408 "FLUX transformer",
2409 progress,
2410 )?);
2411 FluxTransformer::BF16(flux::model::Flux::new(&flux_cfg, flux_vb)?)
2412 });
2413 progress.stage_done(xformer_label, reload_start.elapsed());
2414 }
2415
2416 if let Some((t5_emb, clip_emb)) = Self::restore_prompt_cache(
2417 progress,
2418 prompt_cache,
2419 &req.prompt,
2420 &loaded_device,
2421 loaded_dtype,
2422 )? {
2423 return Self::generate_with_embeddings(
2424 progress,
2425 req,
2426 &mut loaded,
2427 t5_emb,
2428 clip_emb,
2429 seed,
2430 width,
2431 height,
2432 start,
2433 self.base.gpu_ordinal,
2434 );
2435 }
2436
2437 if loaded.t5.model.is_none() {
2438 let label = if loaded.t5.is_parked() {
2439 "Unparking T5 encoder (CPU→GPU)"
2440 } else {
2441 "Reloading T5 encoder (GPU)"
2442 };
2443 progress.stage_start(label);
2444 let reload_start = Instant::now();
2445 if loaded.t5.is_parked() {
2446 loaded.t5.unpark_to_gpu(loaded_dtype, progress)?;
2447 } else {
2448 loaded.t5.reload(&t5_encoder_path, loaded_dtype, progress)?;
2449 }
2450 progress.stage_done(label, reload_start.elapsed());
2451 }
2452 if loaded.clip.model.is_none() {
2453 let label = if loaded.clip.is_parked() {
2454 "Unparking CLIP encoder (CPU→GPU)"
2455 } else {
2456 "Reloading CLIP encoder (GPU)"
2457 };
2458 progress.stage_start(label);
2459 let reload_start = Instant::now();
2460 if loaded.clip.is_parked() {
2461 loaded.clip.unpark_to_gpu(loaded_dtype, progress)?;
2462 } else {
2463 loaded
2464 .clip
2465 .reload(&clip_encoder_path, loaded_dtype, progress)?;
2466 }
2467 progress.stage_done(label, reload_start.elapsed());
2468 }
2469
2470 progress.stage_start("Encoding prompt (T5)");
2471 let encode_t5 = Instant::now();
2472 let t5_emb = park_cond_to_cpu(&loaded.t5.encode(
2476 &req.prompt,
2477 &loaded_device,
2478 loaded_dtype,
2479 )?)?;
2480 progress.stage_done("Encoding prompt (T5)", encode_t5.elapsed());
2481 tracing::info!("T5 encoding complete");
2482
2483 progress.stage_start("Encoding prompt (CLIP)");
2484 let encode_clip = Instant::now();
2485 let clip_emb = park_cond_to_cpu(&loaded.clip.encode(
2486 &req.prompt,
2487 &loaded_device,
2488 loaded_dtype,
2489 )?)?;
2490 progress.stage_done("Encoding prompt (CLIP)", encode_clip.elapsed());
2491 tracing::info!("CLIP encoding complete");
2492 Self::store_prompt_cache(prompt_cache, &req.prompt, &t5_emb, &clip_emb)?;
2495
2496 let is_metal = loaded.device.is_metal();
2509 let park_mode = crate::device::keep_te_in_ram() && !is_metal;
2510 let mut dropped_gpu_encoder = false;
2511 if loaded.t5.on_gpu || is_metal {
2512 if loaded.t5.on_gpu {
2513 dropped_gpu_encoder = true;
2514 }
2515 if park_mode {
2516 loaded.t5.park_to_cpu()?;
2517 tracing::info!(
2518 on_gpu = loaded.t5.on_gpu,
2519 "T5 encoder parked to CPU host RAM"
2520 );
2521 } else {
2522 loaded.t5.drop_weights();
2523 tracing::info!(
2524 on_gpu = loaded.t5.on_gpu,
2525 "T5 encoder dropped to free memory for denoising"
2526 );
2527 }
2528 }
2529 if loaded.clip.on_gpu || is_metal {
2530 if loaded.clip.on_gpu {
2531 dropped_gpu_encoder = true;
2532 }
2533 if park_mode {
2534 loaded.clip.park_to_cpu()?;
2535 tracing::info!(
2536 on_gpu = loaded.clip.on_gpu,
2537 "CLIP encoder parked to CPU host RAM"
2538 );
2539 } else {
2540 loaded.clip.drop_weights();
2541 tracing::info!(
2542 on_gpu = loaded.clip.on_gpu,
2543 "CLIP encoder dropped to free memory for denoising"
2544 );
2545 }
2546 }
2547 if dropped_gpu_encoder {
2553 loaded.device.synchronize()?;
2554 }
2555
2556 Self::generate_with_embeddings(
2557 progress,
2558 req,
2559 &mut loaded,
2560 t5_emb,
2561 clip_emb,
2562 seed,
2563 width,
2564 height,
2565 start,
2566 self.base.gpu_ordinal,
2567 )
2568 })()
2569 }
2570}
2571
2572impl InferenceEngine for FluxEngine {
2573 fn generate(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
2574 self.pending_placement = req.placement.clone();
2575 let result = self.generate_inner(req);
2576 self.pending_placement = None;
2577 result
2578 }
2579
2580 fn model_name(&self) -> &str {
2581 self.base.model_name()
2582 }
2583
2584 fn is_loaded(&self) -> bool {
2585 self.base.is_loaded()
2587 }
2588
2589 fn load(&mut self) -> Result<()> {
2590 FluxEngine::load(self)
2591 }
2592
2593 fn unload(&mut self) {
2594 self.base.unload();
2595 clear_cache(&self.prompt_cache);
2598 self.active_lora = Vec::new();
2602 }
2605
2606 fn set_on_progress(&mut self, callback: ProgressCallback) {
2607 self.base.set_on_progress(callback);
2608 }
2609
2610 fn clear_on_progress(&mut self) {
2611 self.base.clear_on_progress();
2612 }
2613
2614 fn model_paths(&self) -> Option<&mold_core::ModelPaths> {
2615 Some(&self.base.paths)
2616 }
2617}
2618
2619impl FluxEngine {
2620 #[allow(clippy::too_many_arguments)]
2621 fn generate_with_embeddings(
2622 progress: &ProgressReporter,
2623 req: &GenerateRequest,
2624 loaded: &mut LoadedFlux,
2625 t5_emb: candle_core::Tensor,
2626 clip_emb: candle_core::Tensor,
2627 seed: u64,
2628 width: usize,
2629 height: usize,
2630 start: Instant,
2631 gpu_ordinal: usize,
2632 ) -> Result<GenerateResponse> {
2633 let noise_dtype = if loaded.is_quantized {
2635 DType::F32
2636 } else {
2637 loaded.dtype
2638 };
2639 let latent_h = height / 16 * 2;
2640 let latent_w = width / 16 * 2;
2641
2642 let image_seq_len = (latent_h / 2) * (latent_w / 2);
2644 let mut timesteps = if loaded.is_schnell {
2645 flux::sampling::get_schedule(req.steps as usize, None)
2646 } else {
2647 flux::sampling::get_schedule(req.steps as usize, Some((image_seq_len, 0.5, 1.15)))
2648 };
2649
2650 if req.source_image.is_some() {
2651 let start_index = crate::img2img::img2img_start_index(req.steps as usize, req.strength);
2652 timesteps = timesteps[start_index..].to_vec();
2653 tracing::info!(
2654 strength = req.strength,
2655 start_index,
2656 start_timestep = timesteps[0],
2657 schedule = ?timesteps,
2658 remaining_steps = timesteps.len().saturating_sub(1),
2659 "img2img: truncated schedule from strength"
2660 );
2661 }
2662
2663 let (img, inpaint_ctx) = if let Some(ref source_bytes) = req.source_image {
2664 let start_t = timesteps[0];
2665
2666 progress.stage_start("Encoding source image (VAE)");
2667 let encode_start = Instant::now();
2668 let source_tensor = crate::img_utils::decode_source_image(
2669 source_bytes,
2670 req.width,
2671 req.height,
2672 crate::img_utils::NormalizeRange::MinusOneToOne,
2673 &loaded.device,
2674 loaded.vae_dtype,
2675 )?;
2676 let encoded = loaded.vae.encode(&source_tensor)?;
2677 progress.stage_done("Encoding source image (VAE)", encode_start.elapsed());
2678
2679 let noise = crate::engine::seeded_randn(
2680 seed,
2681 &[1, 16, latent_h, latent_w],
2682 &loaded.device,
2683 noise_dtype,
2684 )?;
2685 let encoded = encoded.to_dtype(noise_dtype)?;
2686
2687 let inpaint_ctx = if let Some(ref mask_bytes) = req.mask_image {
2688 let mask = crate::img_utils::decode_mask_image(
2689 mask_bytes,
2690 latent_h,
2691 latent_w,
2692 &loaded.device,
2693 noise_dtype,
2694 )?;
2695 Some(crate::img_utils::InpaintContext {
2696 original_latents: encoded.clone(),
2697 mask,
2698 noise: noise.clone(),
2699 })
2700 } else {
2701 None
2702 };
2703
2704 let img = ((&encoded * (1.0 - start_t))? + (&noise * start_t)?)?;
2706 (img, inpaint_ctx)
2707 } else {
2708 let img = crate::engine::seeded_randn(
2709 seed,
2710 &[1, 16, latent_h, latent_w],
2711 &loaded.device,
2712 noise_dtype,
2713 )?;
2714 (img, None)
2715 };
2716
2717 let t5_emb = t5_emb.to_device(&loaded.device)?;
2722 let clip_emb = clip_emb.to_device(&loaded.device)?;
2723 let (t5_emb_state, clip_emb_state, img_state) = if loaded.is_quantized {
2725 (
2726 t5_emb.to_dtype(DType::F32)?,
2727 clip_emb.to_dtype(DType::F32)?,
2728 img.to_dtype(DType::F32)?,
2729 )
2730 } else {
2731 (t5_emb, clip_emb, img)
2732 };
2733
2734 let state = flux::sampling::State::new(&t5_emb_state, &clip_emb_state, &img_state)?;
2736 let inpaint_ctx = inpaint_ctx
2737 .as_ref()
2738 .map(crate::img2img::pack_flux_inpaint_context)
2739 .transpose()?;
2740
2741 let denoise_label = format!("Denoising ({} steps)", timesteps.len().saturating_sub(1));
2742 progress.stage_start(&denoise_label);
2743 let denoise_start = Instant::now();
2744 tracing::info!(
2745 steps = timesteps.len().saturating_sub(1),
2746 quantized = loaded.is_quantized,
2747 "running denoising loop..."
2748 );
2749
2750 let img = loaded
2752 .flux_model
2753 .as_ref()
2754 .ok_or_else(|| anyhow::anyhow!("transformer not loaded"))?
2755 .denoise(
2756 &state.img,
2757 &state.img_ids,
2758 &state.txt,
2759 &state.txt_ids,
2760 &state.vec,
2761 ×teps,
2762 req.guidance,
2763 progress,
2764 inpaint_ctx.as_ref(),
2765 )?;
2766
2767 let img = flux::sampling::unpack(&img, height, width)?;
2769 progress.stage_done(&denoise_label, denoise_start.elapsed());
2770 tracing::info!("denoising complete, decoding VAE...");
2771
2772 drop(state);
2780 drop(t5_emb_state);
2781 drop(clip_emb_state);
2782 drop(img_state);
2783 let keep_transformer_env = std::env::var("MOLD_FLUX_KEEP_TRANSFORMER")
2784 .map(|v| v == "1")
2785 .unwrap_or(false);
2786
2787 let vae_headroom_bytes = crate::device::activation_bytes(
2802 req.width,
2803 req.height,
2804 1,
2805 crate::device::dtype_bytes(loaded.dtype),
2806 crate::device::ActivationFamily::FluxDit,
2807 );
2808 let free_before_vae = crate::device::free_vram_bytes(gpu_ordinal).unwrap_or(0);
2809 let force_drop_for_headroom =
2810 keep_transformer_env && free_before_vae > 0 && free_before_vae < vae_headroom_bytes;
2811
2812 if !keep_transformer_env || force_drop_for_headroom {
2813 loaded.flux_model = None;
2814 if force_drop_for_headroom {
2815 tracing::info!(
2816 free_mb = free_before_vae / 1024 / 1024,
2817 headroom_mb = vae_headroom_bytes / 1024 / 1024,
2818 "Transformer force-dropped before VAE decode (free VRAM below \
2819 resolution-scaled headroom; overrides MOLD_FLUX_KEEP_TRANSFORMER=1 \
2820 for this request)"
2821 );
2822 } else {
2823 tracing::info!("Transformer dropped to free VRAM for VAE decode");
2824 }
2825 } else {
2826 tracing::info!(
2827 free_mb = free_before_vae / 1024 / 1024,
2828 "Transformer kept loaded (MOLD_FLUX_KEEP_TRANSFORMER=1)"
2829 );
2830 }
2831 loaded.device.synchronize()?;
2840
2841 progress.stage_start("VAE decode");
2846 let vae_decode_start = Instant::now();
2847 let img_for_vae = img.to_dtype(loaded.vae_dtype)?;
2848 let vae = &loaded.vae;
2849 let device_for_sync = loaded.device.clone();
2850 let img = crate::vae_tiling::decode_with_oom_fallback(
2851 &img_for_vae,
2852 |latents| vae.decode(latents).map_err(Into::into),
2853 || {
2854 if let Err(e) = device_for_sync.synchronize() {
2855 tracing::warn!(
2856 "FLUX (parallel) device.synchronize() after VAE OOM failed: {e}"
2857 );
2858 }
2859 },
2860 )?;
2861
2862 let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(DType::U8)?;
2864 let img = img.i(0)?; progress.stage_done("VAE decode", vae_decode_start.elapsed());
2867 tracing::info!("VAE decode complete, encoding output image...");
2868
2869 let output_metadata = build_output_metadata(req, seed, None);
2871 let image_bytes = encode_image(
2872 &img,
2873 req.resolved_output_format(),
2874 req.width,
2875 req.height,
2876 output_metadata.as_ref(),
2877 )?;
2878
2879 let generation_time_ms = start.elapsed().as_millis() as u64;
2880 tracing::info!(generation_time_ms, seed, "generation complete");
2881
2882 Ok(GenerateResponse {
2883 images: vec![ImageData {
2884 data: image_bytes,
2885 format: req.resolved_output_format(),
2886 width: req.width,
2887 height: req.height,
2888 index: 0,
2889 }],
2890 generation_time_ms,
2891 model: req.model.clone(),
2892 seed_used: seed,
2893 video: None,
2894 gpu: None,
2895 })
2896 }
2897}
2898
2899#[cfg(test)]
2900mod tests {
2901 use super::{
2902 effective_loras, flux_rms_norm_scale_aliases, flux_runtime_dtype,
2903 flux_transformer_var_builder, park_cond_to_cpu, LoraBypassMode,
2904 };
2905 use crate::LoadStrategy;
2906 use candle_core::{DType, Device, Result, Tensor};
2907 use candle_nn::VarBuilder;
2908 use mold_core::{GenerateRequest, LoraWeight, ModelPaths, OutputFormat};
2909 use std::collections::HashMap;
2910 use std::path::PathBuf;
2911
2912 #[test]
2917 fn lora_bypass_mode_env_parsing() {
2918 let with_env = |val: Option<&str>| -> LoraBypassMode {
2919 unsafe {
2920 match val {
2921 Some(v) => std::env::set_var("MOLD_LORA_BYPASS", v),
2922 None => std::env::remove_var("MOLD_LORA_BYPASS"),
2923 }
2924 }
2925 let mode = LoraBypassMode::from_env();
2926 unsafe {
2927 std::env::remove_var("MOLD_LORA_BYPASS");
2928 }
2929 mode
2930 };
2931 assert_eq!(with_env(Some("on")), LoraBypassMode::On);
2932 assert_eq!(with_env(Some("ON")), LoraBypassMode::On);
2933 assert_eq!(with_env(Some("1")), LoraBypassMode::On);
2934 assert_eq!(with_env(Some("off")), LoraBypassMode::Off);
2935 assert_eq!(with_env(Some("0")), LoraBypassMode::Off);
2936 assert_eq!(with_env(Some("auto")), LoraBypassMode::Auto);
2937 assert_eq!(with_env(Some("garbage")), LoraBypassMode::Auto);
2938 assert_eq!(with_env(None), LoraBypassMode::Auto);
2939 }
2940
2941 #[test]
2942 fn flux_rms_norm_aliases_detect_weight_suffix_checkpoint() {
2943 use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
2944
2945 let dir = std::env::temp_dir().join(format!(
2946 "mold-flux-rms-alias-{}-{}",
2947 std::process::id(),
2948 std::time::SystemTime::now()
2949 .duration_since(std::time::UNIX_EPOCH)
2950 .unwrap()
2951 .as_nanos()
2952 ));
2953 std::fs::create_dir_all(&dir).unwrap();
2954 let path = dir.join("flux-rms-weight.safetensors");
2955
2956 let data = 1.0f32.to_le_bytes();
2957 let mut tensors = HashMap::new();
2958 tensors.insert(
2959 "model.diffusion_model.double_blocks.0.img_attn.norm.query_norm.weight".to_string(),
2960 TensorView::new(SafeDtype::F32, vec![1], &data).unwrap(),
2961 );
2962 serialize_to_file(&tensors, &None, &path).unwrap();
2963
2964 let aliases = flux_rms_norm_scale_aliases(&path).unwrap();
2965 assert_eq!(
2966 aliases.get("model.diffusion_model.double_blocks.0.img_attn.norm.query_norm.scale"),
2967 Some(
2968 &"model.diffusion_model.double_blocks.0.img_attn.norm.query_norm.weight"
2969 .to_string()
2970 )
2971 );
2972
2973 std::fs::remove_dir_all(&dir).ok();
2974 }
2975
2976 fn dummy_paths(transformer: &str) -> ModelPaths {
2977 ModelPaths {
2978 transformer: PathBuf::from(transformer),
2979 transformer_shards: Vec::new(),
2980 vae: PathBuf::from("ae.safetensors"),
2981 spatial_upscaler: None,
2982 temporal_upscaler: None,
2983 distilled_lora: None,
2984 t5_encoder: Some(PathBuf::from("t5.safetensors")),
2985 clip_encoder: Some(PathBuf::from("clip.safetensors")),
2986 t5_tokenizer: Some(PathBuf::from("t5-tokenizer.json")),
2987 clip_tokenizer: Some(PathBuf::from("clip-tokenizer.json")),
2988 clip_encoder_2: None,
2989 clip_tokenizer_2: None,
2990 text_encoder_files: Vec::new(),
2991 text_tokenizer: None,
2992 decoder: None,
2993 }
2994 }
2995
2996 #[test]
2997 fn forced_offload_uses_sequential_generation_path_for_bf16_flux() {
2998 let mut engine = super::FluxEngine::new(
2999 "flux-dev:bf16".to_string(),
3000 dummy_paths("flux1-dev.safetensors"),
3001 Some(false),
3002 None,
3003 LoadStrategy::Eager,
3004 0,
3005 true,
3006 None,
3007 );
3008
3009 assert!(engine.uses_sequential_generate_path());
3010 }
3011
3012 #[test]
3013 fn forced_offload_defers_eager_load_for_bf16_flux() {
3014 let mut engine = super::FluxEngine::new(
3015 "flux-dev:bf16".to_string(),
3016 dummy_paths("flux1-dev.safetensors"),
3017 Some(false),
3018 None,
3019 LoadStrategy::Eager,
3020 0,
3021 true,
3022 None,
3023 );
3024
3025 assert!(engine.defers_eager_load());
3026 }
3027
3028 fn req_with_loras(
3033 single: Option<LoraWeight>,
3034 plural: Option<Vec<LoraWeight>>,
3035 ) -> GenerateRequest {
3036 GenerateRequest {
3037 prompt: String::new(),
3038 negative_prompt: None,
3039 model: "flux-dev".to_string(),
3040 width: 1024,
3041 height: 1024,
3042 steps: 4,
3043 guidance: 0.0,
3044 seed: None,
3045 batch_size: 1,
3046 output_format: Some(OutputFormat::Png),
3047 embed_metadata: None,
3048 scheduler: None,
3049 cfg_plus: None,
3050 source_image: None,
3051 edit_images: None,
3052 strength: 0.75,
3053 mask_image: None,
3054 control_image: None,
3055 control_model: None,
3056 control_scale: 1.0,
3057 expand: None,
3058 original_prompt: None,
3059 lora: single,
3060 frames: None,
3061 fps: None,
3062 upscale_model: None,
3063 gif_preview: false,
3064 enable_audio: None,
3065 audio_file: None,
3066 audio_file_path: None,
3067 source_video: None,
3068 source_video_path: None,
3069 keyframes: None,
3070 pipeline: None,
3071 loras: plural,
3072 retake_range: None,
3073 spatial_upscale: None,
3074 temporal_upscale: None,
3075 placement: None,
3076 }
3077 }
3078
3079 #[test]
3082 fn effective_loras_drops_zero_scale() {
3083 let req = req_with_loras(
3084 None,
3085 Some(vec![
3086 LoraWeight {
3087 path: "p1".into(),
3088 scale: 0.8,
3089 },
3090 LoraWeight {
3091 path: "p2".into(),
3092 scale: 0.0,
3093 },
3094 LoraWeight {
3095 path: "p3".into(),
3096 scale: 0.5,
3097 },
3098 ]),
3099 );
3100 let stack = effective_loras(&req);
3101 let paths: Vec<&str> = stack.iter().map(|w| w.path.as_str()).collect();
3102 assert_eq!(
3103 paths,
3104 vec!["p1", "p3"],
3105 "p2 (scale=0.0) must be dropped from the effective stack"
3106 );
3107 assert!((stack[0].scale - 0.8).abs() < 1e-9);
3108 assert!((stack[1].scale - 0.5).abs() < 1e-9);
3109 }
3110
3111 #[test]
3113 fn effective_loras_keeps_negative_scales() {
3114 let req = req_with_loras(
3115 None,
3116 Some(vec![LoraWeight {
3117 path: "p1".into(),
3118 scale: -0.3,
3119 }]),
3120 );
3121 let stack = effective_loras(&req);
3122 assert_eq!(stack.len(), 1);
3123 assert!((stack[0].scale - (-0.3)).abs() < 1e-9);
3124 }
3125
3126 #[test]
3128 fn effective_loras_drops_zero_scale_on_single_form() {
3129 let req = req_with_loras(
3130 Some(LoraWeight {
3131 path: "p1".into(),
3132 scale: 0.0,
3133 }),
3134 None,
3135 );
3136 assert!(effective_loras(&req).is_empty());
3137 }
3138
3139 #[test]
3143 fn park_cond_to_cpu_is_idempotent_for_cpu_tensors() {
3144 let cpu_tensor = Tensor::zeros((2, 4), DType::F32, &Device::Cpu).unwrap();
3145 let parked = park_cond_to_cpu(&cpu_tensor).unwrap();
3146 assert!(parked.device().is_cpu(), "CPU input must stay on CPU");
3147 assert_eq!(parked.shape(), cpu_tensor.shape());
3148 }
3149
3150 #[test]
3152 fn park_cond_to_cpu_returns_cpu_tensor_for_any_input() {
3153 let input = Tensor::ones((1, 3), DType::F32, &Device::Cpu).unwrap();
3154 let parked = park_cond_to_cpu(&input).unwrap();
3155 assert!(parked.device().is_cpu(), "output must be on CPU");
3156 assert_eq!(parked.shape(), input.shape());
3157 assert_eq!(parked.dtype(), input.dtype());
3158 }
3159
3160 #[test]
3161 fn flux_var_builder_uses_root_tensors_when_present() -> Result<()> {
3162 let tensors = HashMap::from([(
3163 "img_in.weight".to_string(),
3164 Tensor::zeros((1, 1), DType::F32, &Device::Cpu)?,
3165 )]);
3166 let vb = VarBuilder::from_tensors(tensors, DType::F32, &Device::Cpu);
3167 let resolved = flux_transformer_var_builder(vb);
3168
3169 assert!(resolved.contains_tensor("img_in.weight"));
3170 assert_eq!(resolved.prefix(), "");
3171 Ok(())
3172 }
3173
3174 #[test]
3175 fn flux_var_builder_uses_model_diffusion_model_prefix_when_present() -> Result<()> {
3176 let tensors = HashMap::from([(
3177 "model.diffusion_model.img_in.weight".to_string(),
3178 Tensor::zeros((1, 1), DType::F32, &Device::Cpu)?,
3179 )]);
3180 let vb = VarBuilder::from_tensors(tensors, DType::F32, &Device::Cpu);
3181 let resolved = flux_transformer_var_builder(vb);
3182
3183 assert!(resolved.contains_tensor("img_in.weight"));
3184 assert_eq!(resolved.prefix(), "model.diffusion_model");
3185 Ok(())
3186 }
3187
3188 #[test]
3189 fn flux_runtime_dtype_prefers_f16_for_cuda_fp8_safetensors() {
3190 assert_eq!(flux_runtime_dtype(true, false, true), DType::F16);
3191 assert_eq!(flux_runtime_dtype(true, false, false), DType::BF16);
3192 assert_eq!(flux_runtime_dtype(false, false, true), DType::F32);
3193 }
3194
3195 #[test]
3196 fn flux_runtime_dtype_quantized_matches_gpu_policy() {
3197 assert_eq!(flux_runtime_dtype(true, true, false), DType::BF16);
3198 assert_eq!(flux_runtime_dtype(false, true, false), DType::F32);
3199 assert_eq!(flux_runtime_dtype(true, true, true), DType::BF16);
3200 assert_eq!(flux_runtime_dtype(false, true, true), DType::F32);
3201 }
3202
3203 #[test]
3204 fn fp8_cache_path_includes_file_size() {
3205 let dir = std::env::temp_dir().join(format!("mold-cache-test-{}", std::process::id()));
3207 std::fs::create_dir_all(&dir).unwrap();
3208 let fp8_file = dir.join("transformer.safetensors");
3209 std::fs::write(&fp8_file, vec![0u8; 1024]).unwrap();
3210
3211 let cache_path = super::fp8_gguf_cache_path(&fp8_file);
3212 let filename = cache_path.file_name().unwrap().to_str().unwrap();
3213
3214 assert!(
3216 filename.contains("transformer"),
3217 "should contain stem: {filename}"
3218 );
3219 assert!(
3220 filename.contains("1024"),
3221 "should contain file size: {filename}"
3222 );
3223 assert!(
3224 filename.ends_with(".q8_0.gguf"),
3225 "should end with .q8_0.gguf: {filename}"
3226 );
3227
3228 std::fs::write(&fp8_file, vec![0u8; 2048]).unwrap();
3230 let cache_path2 = super::fp8_gguf_cache_path(&fp8_file);
3231 assert_ne!(
3232 cache_path, cache_path2,
3233 "different file sizes should produce different cache paths"
3234 );
3235
3236 std::fs::remove_dir_all(&dir).ok();
3237 }
3238
3239 #[test]
3240 fn fp8_q8_cache_quantizes_only_block_aligned_last_dim() {
3241 assert!(super::q8_0_can_quantize_dims(&[3072, 3072]));
3242 assert!(super::q8_0_can_quantize_dims(&[1, 64]));
3243 assert!(
3244 !super::q8_0_can_quantize_dims(&[256, 256, 3, 3]),
3245 "conv kernels have total elements divisible by 32, but Q8_0 \
3246 requires the last dimension itself to be block-aligned"
3247 );
3248 assert!(!super::q8_0_can_quantize_dims(&[512, 512, 1, 1]));
3249 assert!(!super::q8_0_can_quantize_dims(&[3072]));
3250 }
3251
3252 #[test]
3253 fn fp8_q8_cache_skips_bundled_text_encoder_and_scalar_tensors() {
3254 assert!(super::fp8_cache_should_skip_tensor(
3255 "text_encoders.clip_l.logit_scale",
3256 &[]
3257 ));
3258 assert!(super::fp8_cache_should_skip_tensor(
3259 "text_encoders.t5xxl.encoder.block.0.layer.0.SelfAttention.q.weight",
3260 &[4096, 4096]
3261 ));
3262 assert!(super::fp8_cache_should_skip_tensor("some.scalar", &[]));
3263 assert!(!super::fp8_cache_should_skip_tensor(
3264 "double_blocks.0.img_attn.qkv.weight",
3265 &[9216, 3072]
3266 ));
3267 }
3268
3269 #[test]
3270 fn fp8_cache_path_lives_under_cache_flux_q8() {
3271 let path = std::path::Path::new("/some/model/my-model.safetensors");
3272 let cache_path = super::fp8_gguf_cache_path(path);
3274 let cache_str = cache_path.to_str().unwrap();
3275 assert!(
3276 cache_str.contains("cache/flux-q8"),
3277 "cache should be under cache/flux-q8: {cache_str}"
3278 );
3279 }
3280
3281 #[test]
3282 fn fp8_cache_temp_paths_are_unique_per_writer() {
3283 let cache_path =
3284 std::path::Path::new("/tmp/agfluxSchnell_realistic23-1234-deadbeef.q8_0.gguf");
3285
3286 let first = super::fp8_gguf_tmp_path(cache_path);
3287 let second = super::fp8_gguf_tmp_path(cache_path);
3288
3289 assert_ne!(first, second);
3290 assert_ne!(first, cache_path);
3291 assert_ne!(second, cache_path);
3292 }
3293
3294 #[test]
3295 fn detects_schnell_from_uppercase_filename() {
3296 let engine = super::FluxEngine::new(
3297 "cv:1153358".to_string(),
3298 dummy_paths("agfluxSchnell_realistic23.safetensors"),
3299 None,
3300 None,
3301 LoadStrategy::Sequential,
3302 0,
3303 false,
3304 None,
3305 );
3306
3307 assert!(engine.detect_is_schnell());
3308 }
3309
3310 #[test]
3311 fn flux_vae_var_builder_accepts_vae_prefix() {
3312 use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
3313 use std::collections::HashMap;
3314
3315 let dir = std::env::temp_dir().join(format!(
3316 "mold-flux-vae-prefix-{}-{}",
3317 std::process::id(),
3318 std::time::SystemTime::now()
3319 .duration_since(std::time::UNIX_EPOCH)
3320 .unwrap()
3321 .as_nanos()
3322 ));
3323 std::fs::create_dir_all(&dir).unwrap();
3324 let path = dir.join("vae-prefix.safetensors");
3325
3326 let data = vec![0u8; 128 * 3 * 3 * 3 * std::mem::size_of::<f32>()];
3327 let shape = vec![128, 3, 3, 3];
3328 let view = TensorView::new(SafeDtype::F32, shape, &data).unwrap();
3329 let mut tensors = HashMap::new();
3330 tensors.insert("vae.encoder.conv_in.weight".to_string(), view);
3331 serialize_to_file(&tensors, &None, &path).unwrap();
3332
3333 let vb = crate::weight_loader::load_safetensors_with_progress(
3334 std::slice::from_ref(&path),
3335 DType::F32,
3336 &Device::Cpu,
3337 "test VAE",
3338 &crate::progress::ProgressReporter::default(),
3339 )
3340 .unwrap();
3341 let vb = super::flux_vae_var_builder(vb);
3342
3343 assert!(vb.contains_tensor("encoder.conv_in.weight"));
3344
3345 std::fs::remove_dir_all(&dir).ok();
3346 }
3347
3348 fn write_test_gguf(path: &std::path::Path, tensor_names: &[&str]) {
3353 let device = Device::Cpu;
3354 let qtensors: Vec<(String, candle_core::quantized::QTensor)> = tensor_names
3355 .iter()
3356 .map(|name| {
3357 let t = Tensor::zeros(1, DType::F32, &device).unwrap();
3358 let qt = candle_core::quantized::QTensor::quantize(
3359 &t,
3360 candle_core::quantized::GgmlDType::F32,
3361 )
3362 .unwrap();
3363 (name.to_string(), qt)
3364 })
3365 .collect();
3366 let refs: Vec<(&str, &candle_core::quantized::QTensor)> =
3367 qtensors.iter().map(|(n, q)| (n.as_str(), q)).collect();
3368 let file = std::fs::File::create(path).unwrap();
3369 let mut writer = std::io::BufWriter::new(file);
3370 candle_core::quantized::gguf_file::write(&mut writer, &[], &refs).unwrap();
3371 }
3372
3373 #[test]
3374 fn gguf_has_embeddings_true_for_complete() {
3375 let dir =
3376 std::env::temp_dir().join(format!("mold-emb-test-complete-{}", std::process::id()));
3377 std::fs::create_dir_all(&dir).unwrap();
3378 let path = dir.join("complete.gguf");
3379 write_test_gguf(
3380 &path,
3381 &[
3382 "img_in.weight",
3383 "img_in.bias",
3384 "double_blocks.0.img_mod.lin.weight",
3385 ],
3386 );
3387 assert!(super::gguf_has_embeddings(&path).unwrap());
3388 std::fs::remove_dir_all(&dir).ok();
3389 }
3390
3391 #[test]
3392 fn gguf_has_embeddings_false_for_incomplete() {
3393 let dir =
3394 std::env::temp_dir().join(format!("mold-emb-test-incomplete-{}", std::process::id()));
3395 std::fs::create_dir_all(&dir).unwrap();
3396 let path = dir.join("incomplete.gguf");
3397 write_test_gguf(
3398 &path,
3399 &[
3400 "double_blocks.0.img_mod.lin.weight",
3401 "single_blocks.0.linear1.weight",
3402 "txt_in.weight",
3403 ],
3404 );
3405 assert!(!super::gguf_has_embeddings(&path).unwrap());
3406 std::fs::remove_dir_all(&dir).ok();
3407 }
3408
3409 #[test]
3410 fn embedding_patched_cache_path_format() {
3411 let dir = std::env::temp_dir().join(format!("mold-emb-cache-fmt-{}", std::process::id()));
3412 std::fs::create_dir_all(&dir).unwrap();
3413 let gguf_file = dir.join("ultrareal.gguf");
3414 std::fs::write(&gguf_file, vec![0u8; 512]).unwrap();
3415
3416 let cache_path = super::embedding_patched_cache_path(&gguf_file);
3417 let cache_str = cache_path.to_str().unwrap();
3418 assert!(
3419 cache_str.contains("cache/flux-embeddings"),
3420 "should be under cache/flux-embeddings: {cache_str}"
3421 );
3422 let filename = cache_path.file_name().unwrap().to_str().unwrap();
3423 assert!(
3424 filename.contains("ultrareal"),
3425 "should contain stem: {filename}"
3426 );
3427 assert!(
3428 filename.contains("512"),
3429 "should contain file size: {filename}"
3430 );
3431 assert!(
3432 filename.ends_with(".patched.gguf"),
3433 "should end with .patched.gguf: {filename}"
3434 );
3435
3436 std::fs::remove_dir_all(&dir).ok();
3437 }
3438
3439 #[test]
3440 fn ensure_gguf_embeddings_noop_for_complete() {
3441 let dir = std::env::temp_dir().join(format!("mold-emb-noop-{}", std::process::id()));
3442 std::fs::create_dir_all(&dir).unwrap();
3443 let path = dir.join("complete.gguf");
3444
3445 write_test_gguf(
3447 &path,
3448 &["img_in.weight", "double_blocks.0.img_mod.lin.weight"],
3449 );
3450
3451 let progress = crate::progress::ProgressReporter::default();
3452 let result = super::ensure_gguf_embeddings(&path, false, &progress, None).unwrap();
3453
3454 assert_eq!(result, path);
3456
3457 std::fs::remove_dir_all(&dir).ok();
3458 }
3459
3460 #[test]
3461 fn ensure_gguf_embeddings_patches_incomplete_with_reference() {
3462 let dir = std::env::temp_dir().join(format!("mold-emb-patch-{}", std::process::id()));
3465 std::fs::create_dir_all(&dir).unwrap();
3466
3467 let incomplete_path = dir.join("ultrareal-test.gguf");
3469 write_test_gguf(
3470 &incomplete_path,
3471 &[
3472 "double_blocks.0.img_mod.lin.weight",
3473 "single_blocks.0.linear1.weight",
3474 "txt_in.weight",
3475 "txt_in.bias",
3476 "final_layer.linear.weight",
3477 ],
3478 );
3479
3480 let models_dir = dir.join("models");
3483 let ref_model_dir = models_dir.join("flux-dev-q8");
3484 std::fs::create_dir_all(&ref_model_dir).unwrap();
3485 let ref_path = ref_model_dir.join("flux1-dev-Q8_0.gguf");
3486
3487 let mut all_tensors: Vec<&str> = super::FLUX_EMBEDDING_TENSORS.to_vec();
3489 all_tensors.extend_from_slice(super::FLUX_GUIDANCE_EMBEDDING_TENSORS);
3490 all_tensors.extend_from_slice(&[
3491 "double_blocks.0.img_mod.lin.weight",
3492 "txt_in.weight",
3493 "txt_in.bias",
3494 ]);
3495 write_test_gguf(&ref_path, &all_tensors);
3496
3497 let progress = crate::progress::ProgressReporter::default();
3498 let result =
3499 super::ensure_gguf_embeddings(&incomplete_path, false, &progress, Some(&models_dir));
3500
3501 let patched_path = result.unwrap();
3502 assert_ne!(
3503 patched_path, incomplete_path,
3504 "should return a different cached path"
3505 );
3506 assert!(patched_path.exists(), "patched GGUF should exist on disk");
3507 assert!(
3508 patched_path.to_str().unwrap().contains("flux-embeddings"),
3509 "patched file should be in flux-embeddings cache"
3510 );
3511
3512 assert!(
3514 super::gguf_has_embeddings(&patched_path).unwrap(),
3515 "patched GGUF should have embeddings"
3516 );
3517
3518 std::fs::remove_dir_all(&dir).ok();
3520 std::fs::remove_file(&patched_path).ok();
3521 let _ = std::fs::remove_dir(patched_path.parent().unwrap());
3522 }
3523
3524 #[test]
3525 fn ensure_gguf_embeddings_cache_is_reused() {
3526 let dir = std::env::temp_dir().join(format!("mold-emb-reuse-{}", std::process::id()));
3528 std::fs::create_dir_all(&dir).unwrap();
3529
3530 let incomplete_path = dir.join("model-for-cache.gguf");
3531 write_test_gguf(&incomplete_path, &["double_blocks.0.img_mod.lin.weight"]);
3532
3533 let cache_path = super::embedding_patched_cache_path(&incomplete_path);
3535 std::fs::create_dir_all(cache_path.parent().unwrap()).unwrap();
3536 write_test_gguf(
3537 &cache_path,
3538 &["img_in.weight", "double_blocks.0.img_mod.lin.weight"],
3539 );
3540
3541 let progress = crate::progress::ProgressReporter::default();
3542 let result =
3543 super::ensure_gguf_embeddings(&incomplete_path, true, &progress, None).unwrap();
3544
3545 assert_eq!(result, cache_path, "should return cached file");
3546
3547 std::fs::remove_dir_all(&dir).ok();
3549 std::fs::remove_file(&cache_path).ok();
3550 let _ = std::fs::remove_dir(cache_path.parent().unwrap());
3552 }
3553
3554 #[test]
3555 fn find_flux_reference_skips_schnell_when_dev_needed() {
3556 let dir = std::env::temp_dir().join(format!(
3560 "mold-ref-picker-{}-{}",
3561 std::process::id(),
3562 std::time::SystemTime::now()
3563 .duration_since(std::time::UNIX_EPOCH)
3564 .unwrap()
3565 .as_nanos()
3566 ));
3567 let models_dir = dir.join("models");
3568 let schnell_dir = models_dir.join("flux-schnell-q8");
3569 std::fs::create_dir_all(&schnell_dir).unwrap();
3570 let schnell_path = schnell_dir.join("flux1-schnell-Q8_0.gguf");
3571
3572 let mut schnell_tensors: Vec<&str> = super::FLUX_EMBEDDING_TENSORS.to_vec();
3574 schnell_tensors.push("double_blocks.0.img_mod.lin.weight");
3575 write_test_gguf(&schnell_path, &schnell_tensors);
3576
3577 let result = super::find_flux_reference_gguf(true, Some(&models_dir));
3579 assert!(
3580 result.is_none(),
3581 "schnell must not be picked as reference for dev targets: got {result:?}"
3582 );
3583
3584 let result = super::find_flux_reference_gguf(false, Some(&models_dir));
3586 assert_eq!(result.as_deref(), Some(schnell_path.as_path()));
3587
3588 std::fs::remove_dir_all(&dir).ok();
3589 }
3590
3591 #[test]
3592 fn find_flux_reference_accepts_dev_candidate_with_guidance() {
3593 let dir = std::env::temp_dir().join(format!(
3597 "mold-ref-dev-{}-{}",
3598 std::process::id(),
3599 std::time::SystemTime::now()
3600 .duration_since(std::time::UNIX_EPOCH)
3601 .unwrap()
3602 .as_nanos()
3603 ));
3604 let models_dir = dir.join("models");
3605 let dev_dir = models_dir.join("flux-dev-q8");
3606 std::fs::create_dir_all(&dev_dir).unwrap();
3607 let dev_path = dev_dir.join("flux1-dev-Q8_0.gguf");
3608
3609 let incomplete: Vec<&str> = super::FLUX_EMBEDDING_TENSORS.to_vec();
3611 write_test_gguf(&dev_path, &incomplete);
3612 assert!(
3613 super::find_flux_reference_gguf(true, Some(&models_dir)).is_none(),
3614 "dev candidate without guidance_in must be rejected for dev targets"
3615 );
3616
3617 let mut complete: Vec<&str> = super::FLUX_EMBEDDING_TENSORS.to_vec();
3619 complete.extend_from_slice(super::FLUX_GUIDANCE_EMBEDDING_TENSORS);
3620 write_test_gguf(&dev_path, &complete);
3621 let picked = super::find_flux_reference_gguf(true, Some(&models_dir))
3622 .expect("complete dev reference must be accepted");
3623 assert_eq!(picked, dev_path);
3624
3625 let picked = super::find_flux_reference_gguf(false, Some(&models_dir))
3627 .expect("dev candidate satisfies schnell targets too");
3628 assert_eq!(picked, dev_path);
3629
3630 std::fs::remove_dir_all(&dir).ok();
3631 }
3632
3633 #[test]
3634 fn find_flux_reference_accepts_krea_when_no_base_dev() {
3635 let dir = std::env::temp_dir().join(format!(
3639 "mold-ref-krea-{}-{}",
3640 std::process::id(),
3641 std::time::SystemTime::now()
3642 .duration_since(std::time::UNIX_EPOCH)
3643 .unwrap()
3644 .as_nanos()
3645 ));
3646 let models_dir = dir.join("models");
3647 let krea_dir = models_dir.join("flux-krea-q8");
3648 std::fs::create_dir_all(&krea_dir).unwrap();
3649 let krea_path = krea_dir.join("flux1-krea-dev-Q8_0.gguf");
3650
3651 let mut complete: Vec<&str> = super::FLUX_EMBEDDING_TENSORS.to_vec();
3652 complete.extend_from_slice(super::FLUX_GUIDANCE_EMBEDDING_TENSORS);
3653 write_test_gguf(&krea_path, &complete);
3654
3655 let picked = super::find_flux_reference_gguf(true, Some(&models_dir))
3656 .expect("complete flux-krea reference must be accepted for dev targets");
3657 assert_eq!(picked, krea_path);
3658
3659 std::fs::remove_dir_all(&dir).ok();
3660 }
3661
3662 #[test]
3663 fn embedding_tensor_names_are_exhaustive() {
3664 let all_embedding_names: Vec<&str> = super::FLUX_EMBEDDING_TENSORS
3671 .iter()
3672 .chain(super::FLUX_GUIDANCE_EMBEDDING_TENSORS.iter())
3673 .copied()
3674 .collect();
3675
3676 assert!(all_embedding_names.contains(&"img_in.weight"));
3678 assert!(all_embedding_names.contains(&"img_in.bias"));
3679
3680 assert!(all_embedding_names.contains(&"time_in.in_layer.weight"));
3682 assert!(all_embedding_names.contains(&"time_in.in_layer.bias"));
3683 assert!(all_embedding_names.contains(&"time_in.out_layer.weight"));
3684 assert!(all_embedding_names.contains(&"time_in.out_layer.bias"));
3685
3686 assert!(all_embedding_names.contains(&"vector_in.in_layer.weight"));
3688 assert!(all_embedding_names.contains(&"vector_in.in_layer.bias"));
3689 assert!(all_embedding_names.contains(&"vector_in.out_layer.weight"));
3690 assert!(all_embedding_names.contains(&"vector_in.out_layer.bias"));
3691
3692 assert!(all_embedding_names.contains(&"guidance_in.in_layer.weight"));
3694 assert!(all_embedding_names.contains(&"guidance_in.in_layer.bias"));
3695 assert!(all_embedding_names.contains(&"guidance_in.out_layer.weight"));
3696 assert!(all_embedding_names.contains(&"guidance_in.out_layer.bias"));
3697
3698 assert_eq!(all_embedding_names.len(), 14);
3700 }
3701}