1use anyhow::{bail, Result};
17use candle_core::{DType, Device, IndexOp, Tensor};
18use candle_nn::VarBuilder;
19use mold_core::{GenerateRequest, GenerateResponse, ImageData, LoraWeight, ModelPaths};
20use std::collections::HashMap;
21use std::path::{Path, PathBuf};
22use std::sync::{Arc, Mutex};
23use std::time::Instant;
24use tokenizers::Tokenizer;
25
26use super::sampling::{self, Flux2State};
27use super::transformer::{Flux2Config, Flux2TransformerWrapper};
28use super::vae::{Flux2AutoEncoder, Flux2VaeConfig};
29
30use crate::cache::{
31 clear_cache, get_or_insert_cached_tensor, prompt_text_key, restore_cached_tensor, CachedTensor,
32 LruCache, DEFAULT_PROMPT_CACHE_CAPACITY,
33};
34use crate::device::{
35 check_memory_budget, effective_device_ref, fmt_gb, free_vram_bytes, memory_status_string,
36 preflight_memory_check, usable_free_vram_bytes,
37};
38use crate::encoders;
39use crate::engine::{rand_seed, InferenceEngine, LoadStrategy};
40use crate::engine_base::EngineBase;
41use crate::image::{build_output_metadata, encode_image};
42use crate::progress::{ProgressCallback, ProgressReporter};
43
/// Components kept resident between requests when the engine is eagerly loaded by
/// `Flux2Engine::load`.
///
/// Field order is also drop order (Rust drops struct fields in declaration order), so
/// keep the layout stable when editing.
struct LoadedFlux2 {
    /// Flux.2 transformer. `None` means it was released, and
    /// `reload_transformer_if_needed` rebuilds it from disk.
    transformer: Option<Flux2TransformerWrapper>,
    /// Qwen3 prompt encoder (GGUF or BF16 variant, as chosen during `load`).
    text_encoder: encoders::qwen3::Qwen3Encoder,
    /// Flux.2 VAE.
    vae: Flux2AutoEncoder,
    /// Device resolved for the transformer during `load`. The VAE and encoder may have
    /// been placed on other devices via placement overrides.
    device: Device,
    /// Compute dtype derived from `device` (`crate::engine::gpu_dtype`).
    dtype: DType,
    /// Dtype the VAE was loaded with (`crate::device::resolve_vae_dtype`).
    vae_dtype: DType,
}
62
/// Inference engine for Flux.2 (Klein) image models.
///
/// Supports eager loading (`load`) and a sequential load-use-drop path
/// (`generate_sequential`) that keeps peak memory low.
pub struct Flux2Engine {
    /// Shared engine state: model name, paths, load strategy, progress reporter and the
    /// eagerly loaded components (if any).
    base: EngineBase<LoadedFlux2>,
    /// Optional explicit Qwen3 encoder variant, forwarded to `resolve_qwen3_variant`.
    qwen3_variant: Option<String>,
    /// Caller-requested block-level transformer offload; validated by
    /// `flux2_offload_decision`.
    offload: bool,
    /// LRU cache of prompt conditioning tensors, keyed by `prompt_text_key(prompt)`.
    prompt_cache: Mutex<LruCache<String, CachedTensor>>,
    /// Per-component device overrides, read through `effective_device_ref`.
    pending_placement: Option<mold_core::types::DevicePlacement>,
    /// LoRA adapters applied whenever the transformer is (re)built by `load_transformer`.
    pending_loras: Vec<LoraWeight>,
    /// When set, the tokenizer and VAE CPU tensors are loaded through this shared pool
    /// instead of directly from disk.
    shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
}
87
88pub(crate) fn effective_flux2_loras(req: &GenerateRequest) -> Vec<LoraWeight> {
93 const ZERO_SCALE_EPS: f64 = 1e-8;
95
96 let raw: Vec<LoraWeight> = if let Some(plural) = &req.loras {
97 if !plural.is_empty() {
98 plural.clone()
99 } else {
100 req.lora.iter().cloned().collect()
101 }
102 } else {
103 req.lora.iter().cloned().collect()
104 };
105 raw.into_iter()
106 .filter(|w| {
107 let keep = w.scale.abs() > ZERO_SCALE_EPS;
108 if !keep {
109 tracing::debug!(
110 path = w.path.as_str(),
111 scale = w.scale,
112 "dropping zero-scale Flux.2 LoRA"
113 );
114 }
115 keep
116 })
117 .collect()
118}
119
/// Outcome of validating a forced block-level offload request for Flux.2.
#[derive(Debug, PartialEq, Eq)]
enum Flux2OffloadDecision {
    /// Offload was not requested.
    Disabled,
    /// Offload was requested and the configuration supports it.
    Selected,
    /// Offload was requested but cannot be honoured; carries a user-facing reason.
    Unsupported(&'static str),
}

/// Decide whether a forced offload request can be honoured.
///
/// GGUF transformers are rejected before LoRA is considered, so a GGUF + LoRA request
/// reports the GGUF reason.
fn flux2_offload_decision(
    forced_offload: bool,
    is_gguf: bool,
    has_lora: bool,
) -> Flux2OffloadDecision {
    match (forced_offload, is_gguf, has_lora) {
        (false, _, _) => Flux2OffloadDecision::Disabled,
        (true, true, _) => Flux2OffloadDecision::Unsupported(
            "Flux.2 block-level offload is only planned for BF16/FP transformers; \
             GGUF variants already use quantized transformer paths",
        ),
        (true, false, true) => Flux2OffloadDecision::Unsupported(
            "Flux.2 block-level offload with LoRA is not wired yet; \
             LoRA merge/bypass semantics need a dedicated offload design",
        ),
        (true, false, false) => Flux2OffloadDecision::Selected,
    }
}
149
150impl Flux2Engine {
151 pub fn new(
153 model_name: String,
154 paths: ModelPaths,
155 qwen3_variant: Option<String>,
156 load_strategy: LoadStrategy,
157 gpu_ordinal: usize,
158 offload: bool,
159 shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
160 ) -> Self {
161 Self {
162 base: EngineBase::new(model_name, paths, load_strategy, gpu_ordinal),
163 qwen3_variant,
164 offload,
165 prompt_cache: Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY)),
166 pending_placement: None,
167 pending_loras: Vec::new(),
168 shared_pool,
169 }
170 }
171
172 #[allow(clippy::too_many_arguments)]
183 pub fn from_single_file(
184 model_name: String,
185 transformer_path: PathBuf,
186 vae_path: PathBuf,
187 text_encoder_files: Vec<PathBuf>,
188 text_tokenizer: Option<PathBuf>,
189 qwen3_variant: Option<String>,
190 load_strategy: LoadStrategy,
191 gpu_ordinal: usize,
192 offload: bool,
193 shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
194 ) -> Result<Self> {
195 if !transformer_path.exists() {
196 bail!(
197 "single-file Flux.2 checkpoint not found: {}",
198 transformer_path.display()
199 );
200 }
201
202 let paths = ModelPaths {
203 transformer: transformer_path,
204 transformer_shards: Vec::new(),
205 vae: vae_path,
206 spatial_upscaler: None,
207 temporal_upscaler: None,
208 distilled_lora: None,
209 t5_encoder: None,
210 clip_encoder: None,
211 t5_tokenizer: None,
212 clip_tokenizer: None,
213 clip_encoder_2: None,
214 clip_tokenizer_2: None,
215 text_encoder_files,
216 text_tokenizer,
217 decoder: None,
218 };
219
220 Ok(Self {
221 base: EngineBase::new(model_name, paths, load_strategy, gpu_ordinal),
222 qwen3_variant,
223 offload,
224 prompt_cache: Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY)),
225 pending_placement: None,
226 pending_loras: Vec::new(),
227 shared_pool,
228 })
229 }
230
231 fn resolve_config(&self) -> Flux2Config {
240 if let Some(cfg) = self.detect_config_from_checkpoint() {
241 return cfg;
242 }
243 if self.base.model_name.to_lowercase().contains("9b") {
244 Flux2Config::klein_9b()
245 } else {
246 Flux2Config::klein()
247 }
248 }
249
250 fn detect_config_from_checkpoint(&self) -> Option<Flux2Config> {
254 if !self.base.paths.transformer_shards.is_empty() {
255 return None;
256 }
257 let path = &self.base.paths.transformer;
258 let is_safetensors = path
259 .extension()
260 .and_then(|e| e.to_str())
261 .is_some_and(|e| e.eq_ignore_ascii_case("safetensors"));
262 if !is_safetensors {
263 return None;
264 }
265 match super::single_file::detect_hidden_size(path) {
266 Ok(Some(4096)) => Some(Flux2Config::klein_9b()),
267 Ok(Some(3072)) => Some(Flux2Config::klein()),
268 _ => None,
270 }
271 }
272
273 fn is_9b(&self) -> bool {
277 if let Some(cfg) = self.detect_config_from_checkpoint() {
278 return cfg.hidden_size == 4096;
279 }
280 self.base.model_name.to_lowercase().contains("9b")
281 }
282
283 fn qwen3_size(&self) -> crate::encoders::variant_resolution::Qwen3Size {
285 if self.is_9b() {
286 crate::encoders::variant_resolution::Qwen3Size::B8
287 } else {
288 crate::encoders::variant_resolution::Qwen3Size::B4
289 }
290 }
291
292 fn qwen3_bf16_config(&self) -> encoders::qwen3_bf16::Qwen3BF16Config {
294 if self.is_9b() {
295 encoders::qwen3_bf16::Qwen3BF16Config::qwen3_8b()
296 } else {
297 encoders::qwen3_bf16::Qwen3BF16Config::qwen3_4b()
298 }
299 }
300
301 fn load_text_tokenizer(&self, tokenizer_path: &Path) -> Result<Arc<Tokenizer>> {
302 if let Some(shared_pool) = &self.shared_pool {
303 return shared_pool.lock().unwrap().load_tokenizer(tokenizer_path);
304 }
305 Tokenizer::from_file(tokenizer_path)
306 .map(Arc::new)
307 .map_err(|e| anyhow::anyhow!("failed to load Qwen3 tokenizer: {e}"))
308 }
309
310 fn load_vae_cpu_tensors(&self) -> Result<Option<Arc<HashMap<String, Tensor>>>> {
311 let Some(shared_pool) = &self.shared_pool else {
312 return Ok(None);
313 };
314 shared_pool
315 .lock()
316 .unwrap()
317 .load_safetensors_cpu_tensors(std::slice::from_ref(&self.base.paths.vae))
318 }
319
320 fn load_vae_var_builder<'a>(
321 &self,
322 dtype: DType,
323 device: &Device,
324 component: &str,
325 ) -> Result<VarBuilder<'a>> {
326 if let Some(tensors) = self.load_vae_cpu_tensors()? {
327 return Ok(crate::encoders::park::varbuilder_from_parked(
328 tensors.as_ref(),
329 dtype,
330 device,
331 ));
332 }
333
334 crate::weight_loader::load_safetensors_with_progress(
335 std::slice::from_ref(&self.base.paths.vae),
336 dtype,
337 device,
338 component,
339 &self.base.progress,
340 )
341 }
342
343 fn img2img_source_normalize_range() -> crate::img_utils::NormalizeRange {
344 crate::img_utils::NormalizeRange::MinusOneToOne
345 }
346
    /// Test-only flag: the sequential img2img path encodes the source image with the VAE
    /// (and frees it) before the transformer is loaded.
    #[cfg(test)]
    fn sequential_img2img_preencodes_source() -> bool {
        true
    }
351
352 fn uses_sequential_generate_path(&self) -> bool {
353 self.base.load_strategy == LoadStrategy::Sequential
354 || self.offload
355 || !self.pending_loras.is_empty()
356 }
357
358 fn load_sequential_vae(
359 &self,
360 device: &Device,
361 gpu_dtype: DType,
362 ) -> Result<(Flux2AutoEncoder, DType)> {
363 let vae_ref =
364 effective_device_ref(self.pending_placement.as_ref(), |adv| Some(adv.vae), false);
365 let vae_device = crate::device::resolve_device(Some(vae_ref), || Ok(device.clone()))?;
366 self.base.progress.stage_start("Loading VAE (GPU)");
367 let vae_stage = Instant::now();
368 let vae_cfg = Flux2VaeConfig::klein();
369 let vae_dtype = crate::device::resolve_vae_dtype(gpu_dtype);
372 let vae_vb = self.load_vae_var_builder(vae_dtype, &vae_device, "VAE")?;
373 let vae = Flux2AutoEncoder::new(&vae_cfg, vae_vb)?;
374 self.base
375 .progress
376 .stage_done("Loading VAE (GPU)", vae_stage.elapsed());
377 Ok((vae, vae_dtype))
378 }
379
380 fn validate_paths(&self) -> Result<std::path::PathBuf> {
382 let text_tokenizer_path = self
383 .base
384 .paths
385 .text_tokenizer
386 .as_ref()
387 .ok_or_else(|| anyhow::anyhow!("text tokenizer path required for Flux.2 models"))?;
388 if !text_tokenizer_path.exists() {
389 bail!(
390 "text tokenizer file not found: {}",
391 text_tokenizer_path.display()
392 );
393 }
394
395 let encoder_paths = self.text_encoder_paths();
396 if encoder_paths.is_empty() {
397 bail!("text encoder paths required for Flux.2 models");
398 }
399 for path in &encoder_paths {
400 if !path.exists() {
401 bail!("text encoder file not found: {}", path.display());
402 }
403 }
404
405 if !self.base.paths.transformer.exists() {
406 bail!(
407 "transformer file not found: {}",
408 self.base.paths.transformer.display()
409 );
410 }
411 if !self.base.paths.vae.exists() {
412 bail!("VAE file not found: {}", self.base.paths.vae.display());
413 }
414
415 Ok(text_tokenizer_path.clone())
416 }
417
418 fn is_gguf_transformer(&self) -> bool {
420 self.base
421 .paths
422 .transformer
423 .extension()
424 .and_then(|e| e.to_str())
425 .map(|e| e.eq_ignore_ascii_case("gguf"))
426 .unwrap_or(false)
427 }
428
    /// Load the Flux.2 transformer, choosing a backend from the checkpoint format and
    /// the engine's LoRA/offload settings.
    ///
    /// Branches, in order:
    /// 1. GGUF checkpoint: quantized transformer. LoRA is applied through
    ///    `gguf_lora_var_builder_flux2` when `pending_loras` is non-empty. `offload` is
    ///    not consulted on this path.
    /// 2. BFL-native single-file safetensors: tensors are remapped via
    ///    `SingleFileBackend`. The result is `Offloaded` when `offload` is set and there
    ///    are no LoRAs; otherwise it is `Flux2TransformerWrapper::BF16` on `device`,
    ///    LoRA-wrapped when needed.
    /// 3. Any other safetensors (single file or `transformer_shards`): one of
    ///    - a LoRA-wrapped mmap backend,
    ///    - CPU-resident offload weights, or
    ///    - a direct load onto `device`.
    ///
    /// Returns the transformer and a progress label describing the path taken.
    fn load_transformer(
        &self,
        cfg: &Flux2Config,
        gpu_dtype: DType,
        device: &Device,
    ) -> Result<(Flux2TransformerWrapper, &'static str)> {
        let has_lora = !self.pending_loras.is_empty();
        if self.is_gguf_transformer() {
            if has_lora {
                let adapters =
                    super::lora::load_lora_adapters(&self.pending_loras, &self.base.progress)?;
                // Pair each loaded adapter with its requested scale and a hash of its path.
                let specs: Vec<super::lora::Flux2LoraSpec<'_>> = adapters
                    .iter()
                    .zip(self.pending_loras.iter())
                    .map(|(adapter, w)| super::lora::Flux2LoraSpec {
                        adapter: adapter.as_ref(),
                        scale: w.scale,
                        path_hash: super::lora::lora_path_hash(&w.path),
                    })
                    .collect();
                let gguf_vb = super::lora::gguf_lora_var_builder_flux2(
                    &self.base.paths.transformer,
                    &specs,
                    device,
                    &self.base.progress,
                    None,
                )?;
                return Ok((
                    Flux2TransformerWrapper::Quantized(
                        super::quantized_transformer::QuantizedFlux2Transformer::new(
                            cfg, gguf_vb, device,
                        )?,
                    ),
                    "Loading Flux.2 transformer (GPU, GGUF + LoRA)",
                ));
            }
            let gguf_vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
                &self.base.paths.transformer,
                device,
            )?;
            Ok((
                Flux2TransformerWrapper::Quantized(
                    super::quantized_transformer::QuantizedFlux2Transformer::new(
                        cfg, gguf_vb, device,
                    )?,
                ),
                "Loading Flux.2 transformer (GPU, GGUF)",
            ))
        } else if self.is_bfl_native_single_file() {
            tracing::info!(
                path = %self.base.paths.transformer.display(),
                "loading Flux.2 transformer from BFL-native single-file checkpoint"
            );
            let backend =
                crate::loader::single_file_backend::SingleFileBackend::from_flux2_singlefile(
                    &self.base.paths.transformer,
                    cfg,
                )?;
            let backend: Box<dyn candle_nn::var_builder::SimpleBackend> = Box::new(backend);
            if self.offload && !has_lora {
                // Offload: weights are materialised on the CPU; the offloaded transformer
                // receives the compute `device` separately.
                let flux_vb = candle_nn::VarBuilder::from_backend(backend, gpu_dtype, Device::Cpu);
                return Ok((
                    Flux2TransformerWrapper::Offloaded(
                        super::transformer::OffloadedFlux2Transformer::new(cfg, flux_vb, device)?,
                    ),
                    "Loading Flux.2 transformer (offload, BF16, single-file remap)",
                ));
            }
            let backend = if has_lora {
                let adapters =
                    super::lora::load_lora_adapters(&self.pending_loras, &self.base.progress)?;
                let specs: Vec<super::lora::Flux2LoraSpec<'_>> = adapters
                    .iter()
                    .zip(self.pending_loras.iter())
                    .map(|(adapter, w)| super::lora::Flux2LoraSpec {
                        adapter: adapter.as_ref(),
                        scale: w.scale,
                        path_hash: super::lora::lora_path_hash(&w.path),
                    })
                    .collect();
                // The remapped single-file backend is wrapped using the Diffusers key space.
                super::lora::wrap_backend_with_lora(
                    backend,
                    &specs,
                    super::lora::Flux2KeySpace::Diffusers,
                    &self.base.progress,
                    None,
                )?
            } else {
                backend
            };
            let flux_vb = candle_nn::VarBuilder::from_backend(backend, gpu_dtype, device.clone());
            let label = if has_lora {
                "Loading Flux.2 transformer (GPU, BF16, single-file remap + LoRA)"
            } else {
                "Loading Flux.2 transformer (GPU, BF16, single-file remap)"
            };
            Ok((
                Flux2TransformerWrapper::BF16(super::transformer::Flux2Transformer::new(
                    cfg, flux_vb,
                )?),
                label,
            ))
        } else {
            // Sharded checkpoints list their files in `transformer_shards`; otherwise the
            // single `transformer` path is used.
            let xformer_paths = if !self.base.paths.transformer_shards.is_empty() {
                self.base.paths.transformer_shards.clone()
            } else {
                vec![self.base.paths.transformer.clone()]
            };
            // `offloaded_label` is `Some` only for the offload branch, which builds an
            // `Offloaded` wrapper below instead of a BF16 one.
            let (flux_vb, offloaded_label) = if has_lora {
                use candle_core::safetensors::MmapedSafetensors;
                let path_refs: Vec<&std::path::Path> =
                    xformer_paths.iter().map(|p| p.as_path()).collect();
                // SAFETY(review): memory-mapping is only sound while the shard files are
                // not modified or truncated; this assumes checkpoints are not rewritten
                // during loading.
                let st = unsafe { MmapedSafetensors::multi(&path_refs)? };
                // Minimal `SimpleBackend` over the mmapped shards so the LoRA wrapper can
                // intercept tensor reads. Tensors are cast to the requested dtype.
                // NOTE(review): `get` ignores the requested shape `_s`, so a shape
                // mismatch is not reported here. Confirm the LoRA wrapper or model code
                // validates shapes.
                struct MmapBackend {
                    st: MmapedSafetensors,
                }
                impl candle_nn::var_builder::SimpleBackend for MmapBackend {
                    fn get(
                        &self,
                        _s: candle_core::Shape,
                        name: &str,
                        _h: candle_nn::Init,
                        dtype: DType,
                        dev: &Device,
                    ) -> candle_core::Result<Tensor> {
                        let t = self.st.load(name, dev)?;
                        if t.dtype() != dtype {
                            t.to_dtype(dtype)
                        } else {
                            Ok(t)
                        }
                    }
                    fn get_unchecked(
                        &self,
                        name: &str,
                        dtype: DType,
                        dev: &Device,
                    ) -> candle_core::Result<Tensor> {
                        let t = self.st.load(name, dev)?;
                        if t.dtype() != dtype {
                            t.to_dtype(dtype)
                        } else {
                            Ok(t)
                        }
                    }
                    fn contains_tensor(&self, name: &str) -> bool {
                        self.st.get(name).is_ok()
                    }
                }
                let inner: Box<dyn candle_nn::var_builder::SimpleBackend> =
                    Box::new(MmapBackend { st });
                let adapters =
                    super::lora::load_lora_adapters(&self.pending_loras, &self.base.progress)?;
                let specs: Vec<super::lora::Flux2LoraSpec<'_>> = adapters
                    .iter()
                    .zip(self.pending_loras.iter())
                    .map(|(adapter, w)| super::lora::Flux2LoraSpec {
                        adapter: adapter.as_ref(),
                        scale: w.scale,
                        path_hash: super::lora::lora_path_hash(&w.path),
                    })
                    .collect();
                let wrapped = super::lora::wrap_backend_with_lora(
                    inner,
                    &specs,
                    super::lora::Flux2KeySpace::Diffusers,
                    &self.base.progress,
                    None,
                )?;
                (
                    candle_nn::VarBuilder::from_backend(wrapped, gpu_dtype, device.clone()),
                    None,
                )
            } else if self.offload {
                // Offload: weights stay on the CPU until the offloaded transformer moves
                // them to `device`.
                (
                    crate::weight_loader::load_safetensors_with_progress(
                        &xformer_paths,
                        gpu_dtype,
                        &Device::Cpu,
                        "Flux.2 transformer (offload blocks)",
                        &self.base.progress,
                    )?,
                    Some("Loading Flux.2 transformer (offload, BF16)"),
                )
            } else {
                (
                    crate::weight_loader::load_safetensors_with_progress(
                        &xformer_paths,
                        gpu_dtype,
                        device,
                        "Flux.2 transformer",
                        &self.base.progress,
                    )?,
                    None,
                )
            };
            if let Some(label) = offloaded_label {
                return Ok((
                    Flux2TransformerWrapper::Offloaded(
                        super::transformer::OffloadedFlux2Transformer::new(cfg, flux_vb, device)?,
                    ),
                    label,
                ));
            }
            let label = if has_lora {
                "Loading Flux.2 transformer (GPU, BF16 + LoRA)"
            } else {
                "Loading Flux.2 transformer (GPU, BF16)"
            };
            Ok((
                Flux2TransformerWrapper::BF16(super::transformer::Flux2Transformer::new(
                    cfg, flux_vb,
                )?),
                label,
            ))
        }
    }
665
666 fn is_bfl_native_single_file(&self) -> bool {
675 if !self.base.paths.transformer_shards.is_empty() {
676 return false;
677 }
678 let path = &self.base.paths.transformer;
679 let is_safetensors = path
680 .extension()
681 .and_then(|e| e.to_str())
682 .is_some_and(|e| e.eq_ignore_ascii_case("safetensors"));
683 if !is_safetensors {
684 return false;
685 }
686 matches!(
687 super::single_file::detect_format(path),
688 Ok(super::single_file::Flux2SingleFileFormat::BflNative)
689 | Ok(super::single_file::Flux2SingleFileFormat::BflNativeRoot)
690 | Ok(super::single_file::Flux2SingleFileFormat::Nvfp4)
691 )
692 }
693
694 fn reload_transformer_if_needed(&mut self) -> Result<()> {
697 let needs_reload = self
698 .base
699 .loaded
700 .as_ref()
701 .map(|l| l.transformer.is_none())
702 .unwrap_or(false);
703
704 if needs_reload {
705 let cfg = self.resolve_config();
706 self.base
707 .progress
708 .stage_start("Reloading Flux.2 transformer");
709 let reload_start = Instant::now();
710 let (transformer, _label) = self.load_transformer(
711 &cfg,
712 self.base.loaded.as_ref().unwrap().dtype,
713 &self.base.loaded.as_ref().unwrap().device.clone(),
714 )?;
715 self.base.loaded.as_mut().unwrap().transformer = Some(transformer);
716 self.base
717 .progress
718 .stage_done("Reloading Flux.2 transformer", reload_start.elapsed());
719 }
720 Ok(())
721 }
722
723 fn should_delay_transformer_reload_for_prompt_encode(
724 load_strategy: LoadStrategy,
725 transformer_loaded: bool,
726 ) -> bool {
727 load_strategy == LoadStrategy::Eager && !transformer_loaded
728 }
729
730 fn text_encoder_paths(&self) -> Vec<std::path::PathBuf> {
732 if !self.base.paths.text_encoder_files.is_empty() {
733 self.base.paths.text_encoder_files.clone()
734 } else {
735 self.base
737 .paths
738 .t5_encoder
739 .as_ref()
740 .map(|p| vec![p.clone()])
741 .unwrap_or_default()
742 }
743 }
744
    /// Qwen3 hidden-state layer indices passed to `encode_with_layers` by
    /// `encode_and_stack`. Their outputs presumably form the stacked prompt conditioning
    /// (confirm against `Qwen3Encoder::encode_with_layers`).
    const QWEN3_HIDDEN_LAYERS: [usize; 3] = [9, 18, 27];
754
755 fn encode_and_stack(
756 encoder: &mut encoders::qwen3::Qwen3Encoder,
757 prompt: &str,
758 target_device: &Device,
759 target_dtype: DType,
760 ) -> Result<Tensor> {
761 let (stacked, _token_count) = encoder.encode_with_layers(
763 prompt,
764 target_device,
765 target_dtype,
766 &Self::QWEN3_HIDDEN_LAYERS,
767 )?;
768 Ok(stacked)
769 }
770
771 fn encode_prompt_cached(
772 progress: &ProgressReporter,
773 prompt_cache: &Mutex<LruCache<String, CachedTensor>>,
774 encoder: &mut encoders::qwen3::Qwen3Encoder,
775 prompt: &str,
776 target_device: &Device,
777 target_dtype: DType,
778 ) -> Result<Tensor> {
779 let cache_key = prompt_text_key(prompt);
780 let (txt_emb, cache_hit) = get_or_insert_cached_tensor(
781 prompt_cache,
782 cache_key,
783 target_device,
784 target_dtype,
785 || {
786 progress.stage_start("Encoding prompt (Qwen3)");
787 let encode_start = Instant::now();
788 let txt_emb = Self::encode_and_stack(encoder, prompt, target_device, target_dtype)?;
789 progress.stage_done("Encoding prompt (Qwen3)", encode_start.elapsed());
790 Ok(txt_emb)
791 },
792 )?;
793 if cache_hit {
794 progress.cache_hit("prompt conditioning");
795 }
796 Ok(txt_emb)
797 }
798
    /// Eagerly load every Flux.2 component (transformer, VAE, Qwen3 text encoder) into
    /// `self.base.loaded`.
    ///
    /// Returns immediately in two cases:
    /// - components are already loaded;
    /// - the load strategy is `LoadStrategy::Sequential`. In that mode
    ///   `generate_sequential` loads and drops components per request.
    ///
    /// # Errors
    /// Propagates failures from path validation, device resolution, and component
    /// loading.
    pub fn load(&mut self) -> Result<()> {
        if self.base.loaded.is_some() {
            return Ok(());
        }

        if self.base.load_strategy == LoadStrategy::Sequential {
            return Ok(());
        }

        tracing::info!(model = %self.base.model_name, "loading Flux.2 Klein model components...");

        let text_tokenizer_path = self.validate_paths()?;

        let cpu = Device::Cpu;
        // The transformer device honours a placement override; otherwise a device is
        // created for `gpu_ordinal`.
        let transformer_ref = effective_device_ref(
            self.pending_placement.as_ref(),
            |adv| Some(adv.transformer),
            false,
        );
        let device = crate::device::resolve_device(Some(transformer_ref), || {
            crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)
        })?;
        let gpu_dtype = crate::engine::gpu_dtype(&device);

        tracing::info!("GPU device: {:?}, GPU dtype: {:?}", device, gpu_dtype);

        // The transformer and VAE are loaded before the text encoder, so the free-VRAM
        // reading below reflects what is left for Qwen3.
        let flux2_cfg = self.resolve_config();
        let xformer_stage = Instant::now();
        let (transformer, xformer_label) = self.load_transformer(&flux2_cfg, gpu_dtype, &device)?;
        self.base
            .progress
            .stage_done(xformer_label, xformer_stage.elapsed());

        // The VAE may be placed on its own device; it defaults to the transformer device.
        let vae_ref =
            effective_device_ref(self.pending_placement.as_ref(), |adv| Some(adv.vae), false);
        let vae_device = crate::device::resolve_device(Some(vae_ref), || Ok(device.clone()))?;
        self.base.progress.stage_start("Loading VAE (GPU)");
        let vae_stage = Instant::now();
        tracing::info!(path = %self.base.paths.vae.display(), "loading VAE on GPU...");
        let vae_cfg = Flux2VaeConfig::klein();
        let vae_dtype = crate::device::resolve_vae_dtype(gpu_dtype);
        let vae_vb = self.load_vae_var_builder(vae_dtype, &vae_device, "VAE")?;
        let vae = Flux2AutoEncoder::new(&vae_cfg, vae_vb)?;
        self.base
            .progress
            .stage_done("Loading VAE (GPU)", vae_stage.elapsed());
        tracing::info!("VAE loaded on GPU");

        // `free_raw` is only reported to the user. The `usable_free_vram_bytes` figure
        // (`free`) is what drives encoder variant selection. Both default to 0 when the
        // query fails.
        let free_raw = free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
        let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
        if free_raw > 0 {
            self.base.progress.info(&format!(
                "Free VRAM after transformer+VAE: {}",
                fmt_gb(free_raw)
            ));
        }

        self.base.progress.stage_start("Selecting Qwen3 encoder");
        let resolve_start = Instant::now();
        let qwen3_size = self.qwen3_size();
        let (encoder_paths, is_gguf, on_gpu, device_label) = {
            let bf16_paths = self.text_encoder_paths();
            let have_bf16 = !bf16_paths.is_empty() && bf16_paths.iter().all(|p| p.exists());
            crate::encoders::variant_resolution::resolve_qwen3_variant(
                &self.base.progress,
                self.qwen3_variant.as_deref(),
                &device,
                free,
                &bf16_paths,
                have_bf16,
                true,
                qwen3_size,
            )?
        };
        self.base
            .progress
            .stage_done("Selecting Qwen3 encoder", resolve_start.elapsed());

        // A Qwen3 placement override can move the encoder away from the device chosen by
        // variant resolution, so `on_gpu` is recomputed from the final device. CPU
        // encoders run in F32.
        let qwen3_ref = effective_device_ref(self.pending_placement.as_ref(), |adv| adv.qwen, true);
        let auto_enc_device = if on_gpu { device.clone() } else { cpu.clone() };
        let enc_device_owned =
            crate::device::resolve_device(Some(qwen3_ref), || Ok(auto_enc_device.clone()))?;
        let enc_device = &enc_device_owned;
        let on_gpu = !enc_device.is_cpu();
        let enc_dtype = if on_gpu { gpu_dtype } else { DType::F32 };
        let bf16_cfg = self.qwen3_bf16_config();

        let enc_stage_label = format!("Loading Qwen3 encoder ({device_label})");
        self.base.progress.stage_start(&enc_stage_label);
        let enc_stage = Instant::now();
        let text_tokenizer = self.load_text_tokenizer(&text_tokenizer_path)?;

        // The GGUF variant uses only the first resolved path; the BF16 variant may span
        // several files.
        let text_encoder = if is_gguf {
            encoders::qwen3::Qwen3Encoder::load_gguf_with_tokenizer(
                &encoder_paths[0],
                &text_tokenizer_path,
                Some(text_tokenizer),
                enc_device,
                &bf16_cfg,
            )?
        } else {
            encoders::qwen3::Qwen3Encoder::load_bf16_with_tokenizer(
                &encoder_paths,
                &text_tokenizer_path,
                Some(text_tokenizer),
                enc_device,
                enc_dtype,
                &bf16_cfg,
                &self.base.progress,
            )?
        };
        self.base
            .progress
            .stage_done(&enc_stage_label, enc_stage.elapsed());
        tracing::info!(device = %device_label, "Qwen3 encoder loaded");

        self.base.loaded = Some(LoadedFlux2 {
            transformer: Some(transformer),
            text_encoder,
            vae,
            device,
            dtype: gpu_dtype,
            vae_dtype,
        });

        tracing::info!(model = %self.base.model_name, "all Flux.2 model components loaded successfully");
        Ok(())
    }
942
943 fn generate_sequential(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
945 let text_tokenizer_path = self.validate_paths()?;
946 let is_gguf = self.is_gguf_transformer();
947
948 match flux2_offload_decision(self.offload, is_gguf, !self.pending_loras.is_empty()) {
949 Flux2OffloadDecision::Disabled => {}
950 Flux2OffloadDecision::Unsupported(reason) => bail!("{reason}"),
951 Flux2OffloadDecision::Selected => {}
952 }
953
954 if let Some(warning) = check_memory_budget(&self.base.paths, LoadStrategy::Sequential) {
955 self.base.progress.info(&warning);
956 }
957
958 let transformer_ref = effective_device_ref(
959 self.pending_placement.as_ref(),
960 |adv| Some(adv.transformer),
961 false,
962 );
963 let device = crate::device::resolve_device(Some(transformer_ref), || {
964 crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)
965 })?;
966 let gpu_dtype = crate::engine::gpu_dtype(&device);
967
968 let start = Instant::now();
969 let seed = req.seed.unwrap_or_else(rand_seed);
970
971 let width = req.width as usize;
972 let height = req.height as usize;
973
974 tracing::info!(
975 prompt = %req.prompt,
976 seed, width, height,
977 steps = req.steps,
978 "starting sequential Flux.2 generation"
979 );
980
981 self.base
982 .progress
983 .info("Using sequential loading (load-use-drop) to minimize peak memory");
984
985 let cache_key = prompt_text_key(&req.prompt);
989 let txt_emb = if let Some(tensor) =
990 restore_cached_tensor(&self.prompt_cache, &cache_key, &device, gpu_dtype)?
991 {
992 self.base.progress.cache_hit("prompt conditioning");
993 tensor
994 } else {
995 let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
997 self.base.progress.stage_start("Selecting Qwen3 encoder");
998 let resolve_start = Instant::now();
999 let qwen3_size = self.qwen3_size();
1000 let (encoder_paths, is_gguf, on_gpu, device_label) = {
1001 let bf16_paths = self.text_encoder_paths();
1002 let have_bf16 = !bf16_paths.is_empty() && bf16_paths.iter().all(|p| p.exists());
1003 crate::encoders::variant_resolution::resolve_qwen3_variant(
1004 &self.base.progress,
1005 self.qwen3_variant.as_deref(),
1006 &device,
1007 free,
1008 &bf16_paths,
1009 have_bf16,
1010 true,
1011 qwen3_size,
1012 )?
1013 };
1014 self.base
1015 .progress
1016 .stage_done("Selecting Qwen3 encoder", resolve_start.elapsed());
1017
1018 let qwen3_ref =
1019 effective_device_ref(self.pending_placement.as_ref(), |adv| adv.qwen, true);
1020 let auto_enc_device = if on_gpu { device.clone() } else { Device::Cpu };
1021 let enc_device_owned =
1022 crate::device::resolve_device(Some(qwen3_ref), || Ok(auto_enc_device.clone()))?;
1023 let enc_device = &enc_device_owned;
1024 let on_gpu = !enc_device.is_cpu();
1025 let enc_dtype = if on_gpu { gpu_dtype } else { DType::F32 };
1026 let bf16_cfg = self.qwen3_bf16_config();
1027
1028 let enc_size: u64 = encoder_paths
1030 .iter()
1031 .filter_map(|p| std::fs::metadata(p).ok().map(|m| m.len()))
1032 .sum();
1033 let enc_activation_budget = crate::device::activation_bytes(
1034 req.width,
1035 req.height,
1036 1,
1037 crate::device::dtype_bytes(enc_dtype),
1038 crate::device::ActivationFamily::SmallTransformer,
1039 );
1040 preflight_memory_check("Qwen3 encoder", enc_size, enc_activation_budget)?;
1041 if let Some(status) = memory_status_string() {
1042 self.base.progress.info(&status);
1043 }
1044
1045 let enc_stage_label = format!("Loading Qwen3 encoder ({device_label})");
1046 self.base.progress.stage_start(&enc_stage_label);
1047 let enc_stage = Instant::now();
1048 let text_tokenizer = self.load_text_tokenizer(&text_tokenizer_path)?;
1049
1050 let mut text_encoder = if is_gguf {
1051 encoders::qwen3::Qwen3Encoder::load_gguf_with_tokenizer(
1052 &encoder_paths[0],
1053 &text_tokenizer_path,
1054 Some(text_tokenizer),
1055 enc_device,
1056 &bf16_cfg,
1057 )?
1058 } else {
1059 encoders::qwen3::Qwen3Encoder::load_bf16_with_tokenizer(
1060 &encoder_paths,
1061 &text_tokenizer_path,
1062 Some(text_tokenizer),
1063 enc_device,
1064 enc_dtype,
1065 &bf16_cfg,
1066 &self.base.progress,
1067 )?
1068 };
1069 self.base
1070 .progress
1071 .stage_done(&enc_stage_label, enc_stage.elapsed());
1072
1073 let txt_emb = Self::encode_prompt_cached(
1074 &self.base.progress,
1075 &self.prompt_cache,
1076 &mut text_encoder,
1077 &req.prompt,
1078 &device,
1079 gpu_dtype,
1080 )?;
1081
1082 drop(text_encoder);
1084 self.base.progress.info("Freed Qwen3 encoder");
1085 tracing::info!("Qwen3 encoder dropped (sequential mode)");
1086
1087 txt_emb
1088 };
1089
1090 let latent_h = height.div_ceil(8);
1091 let latent_w = width.div_ceil(8);
1092
1093 let image_seq_len = (height / 16) * (width / 16);
1095 let mut timesteps = sampling::get_schedule(req.steps as usize, image_seq_len);
1096
1097 if req.source_image.is_some() {
1098 let (trimmed, start_index) =
1099 crate::img2img::trim_schedule_tail(×teps, req.steps as usize, req.strength);
1100 timesteps = trimmed;
1101 tracing::info!(
1102 strength = req.strength,
1103 start_index,
1104 start_timestep = timesteps[0],
1105 schedule = ?timesteps,
1106 remaining_steps = timesteps.len().saturating_sub(1),
1107 "img2img: truncated schedule from strength"
1108 );
1109 }
1110
1111 let (img, inpaint_ctx) = if let Some(ref source_bytes) = req.source_image {
1116 let start_t = timesteps[0];
1117 let (vae, _vae_dtype) = self.load_sequential_vae(&device, gpu_dtype)?;
1118
1119 self.base
1120 .progress
1121 .stage_start("Encoding source image (VAE)");
1122 let encode_start = Instant::now();
1123 let source_tensor = crate::img_utils::decode_source_image(
1124 source_bytes,
1125 req.width,
1126 req.height,
1127 Self::img2img_source_normalize_range(),
1128 &device,
1129 gpu_dtype,
1130 )?;
1131 let encoded = vae.encode(&source_tensor)?;
1132 self.base
1133 .progress
1134 .stage_done("Encoding source image (VAE)", encode_start.elapsed());
1135
1136 let prepared = crate::img2img::prepare_flow_match_img2img(
1137 &encoded,
1138 seed,
1139 &[1, 32, latent_h, latent_w],
1140 start_t,
1141 req.mask_image.as_deref(),
1142 latent_h,
1143 latent_w,
1144 &device,
1145 gpu_dtype,
1146 )?;
1147 drop(vae);
1148 drop(encoded);
1149 drop(source_tensor);
1150 device.synchronize()?;
1151 self.base.progress.info("Freed VAE after source encoding");
1152 (prepared.initial_latents, prepared.inpaint_ctx)
1153 } else {
1154 let img = crate::engine::seeded_randn(
1155 seed,
1156 &[1, 32, latent_h, latent_w],
1157 &device,
1158 gpu_dtype,
1159 )?;
1160 (img, None)
1161 };
1162
1163 let state = Flux2State::new(&txt_emb, &img)?;
1164 let inpaint_ctx = inpaint_ctx
1165 .as_ref()
1166 .map(crate::img2img::pack_flux_inpaint_context)
1167 .transpose()?;
1168
1169 let xformer_size = std::fs::metadata(&self.base.paths.transformer)
1171 .map(|m| m.len())
1172 .unwrap_or(0);
1173 let xformer_activation_budget = crate::device::activation_bytes(
1174 req.width,
1175 req.height,
1176 1,
1177 crate::device::dtype_bytes(gpu_dtype),
1178 crate::device::ActivationFamily::Flux2Dit,
1179 );
1180 preflight_memory_check(
1181 "Flux.2 transformer",
1182 xformer_size,
1183 xformer_activation_budget,
1184 )?;
1185 if let Some(status) = memory_status_string() {
1186 self.base.progress.info(&status);
1187 }
1188
1189 let flux2_cfg = self.resolve_config();
1190 let xformer_stage = Instant::now();
1191 let (transformer, xformer_label) = self.load_transformer(&flux2_cfg, gpu_dtype, &device)?;
1192 self.base
1193 .progress
1194 .stage_done(xformer_label, xformer_stage.elapsed());
1195
1196 let denoise_label = format!("Denoising ({} steps)", timesteps.len().saturating_sub(1));
1197 self.base.progress.stage_start(&denoise_label);
1198 let denoise_start = Instant::now();
1199
1200 let img = transformer.denoise(
1201 &state.img,
1202 &state.img_ids,
1203 &state.txt,
1204 &state.txt_ids,
1205 &state.vec,
1206 ×teps,
1207 req.guidance,
1208 &self.base.progress,
1209 inpaint_ctx.as_ref(),
1210 )?;
1211
1212 let img = sampling::unpack(&img, height, width)?;
1213
1214 self.base
1215 .progress
1216 .stage_done(&denoise_label, denoise_start.elapsed());
1217
1218 drop(inpaint_ctx);
1220 drop(transformer);
1221 self.base.progress.info("Freed Flux.2 transformer");
1222 drop(state);
1223 drop(txt_emb);
1224 device.synchronize()?;
1225 tracing::info!("Transformer dropped (sequential mode), decoding VAE...");
1226
1227 let (vae, vae_dtype) = self.load_sequential_vae(&device, gpu_dtype)?;
1228
1229 self.base.progress.stage_start("VAE decode");
1231 let vae_decode_start = Instant::now();
1232 if let Ok(dump_path) = std::env::var("MOLD_FLUX2_DUMP_LATENT") {
1234 let latent_f32 = img
1235 .to_dtype(DType::F32)?
1236 .to_device(&candle_core::Device::Cpu)?;
1237 let dims = latent_f32.dims().to_vec();
1238 let v: Vec<f32> = latent_f32.flatten_all()?.to_vec1()?;
1239 let mut bytes = Vec::with_capacity(8 * 4 + v.len() * 4);
1240 bytes.extend_from_slice(&(dims.len() as u32).to_le_bytes());
1241 for d in &dims {
1242 bytes.extend_from_slice(&(*d as u32).to_le_bytes());
1243 }
1244 for x in &v {
1245 bytes.extend_from_slice(&x.to_le_bytes());
1246 }
1247 std::fs::write(&dump_path, &bytes)?;
1248 tracing::info!(path = %dump_path, dims = ?dims, "dumped pre-VAE latent");
1249 }
1250 let img_for_vae = img.to_dtype(vae_dtype)?;
1251 let device_for_sync = device.clone();
1252 let img = crate::vae_tiling::decode_with_oom_fallback(
1253 &img_for_vae,
1254 |latents| vae.decode(latents).map_err(Into::into),
1255 || {
1256 if let Err(e) = device_for_sync.synchronize() {
1257 tracing::warn!(
1258 "FLUX2 (sequential) device.synchronize() after VAE OOM failed: {e}"
1259 );
1260 }
1261 },
1262 )?;
1263
1264 let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(DType::U8)?;
1265 let img = img.i(0)?;
1266
1267 self.base
1268 .progress
1269 .stage_done("VAE decode", vae_decode_start.elapsed());
1270
1271 let output_metadata = build_output_metadata(req, seed, None);
1272 let image_bytes = encode_image(
1273 &img,
1274 req.resolved_output_format(),
1275 req.width,
1276 req.height,
1277 output_metadata.as_ref(),
1278 )?;
1279
1280 let generation_time_ms = start.elapsed().as_millis() as u64;
1281 tracing::info!(generation_time_ms, seed, "sequential generation complete");
1282
1283 Ok(GenerateResponse {
1284 images: vec![ImageData {
1285 data: image_bytes,
1286 format: req.resolved_output_format(),
1287 width: req.width,
1288 height: req.height,
1289 index: 0,
1290 }],
1291 generation_time_ms,
1292 model: req.model.clone(),
1293 seed_used: seed,
1294 video: None,
1295 gpu: None,
1296 })
1297 }
1298}
1299
1300impl Flux2Engine {
1305 fn generate_inner(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1306 if req.scheduler.is_some() {
1307 tracing::warn!(
1308 "scheduler selection not supported for Flux.2 (flow-matching), ignoring"
1309 );
1310 }
1311 if req.guidance != 0.0 {
1312 tracing::debug!(
1313 guidance = req.guidance,
1314 "Flux.2 Klein is distilled — guidance value is ignored (no guidance embedding)"
1315 );
1316 }
1317 if self.uses_sequential_generate_path() {
1319 return self.generate_sequential(req);
1320 }
1321
1322 let delay_transformer_reload = self.base.loaded.as_ref().is_some_and(|loaded| {
1330 Self::should_delay_transformer_reload_for_prompt_encode(
1331 self.base.load_strategy,
1332 loaded.transformer.is_some(),
1333 )
1334 });
1335 if delay_transformer_reload {
1336 tracing::info!(
1337 "delaying Flux.2 transformer reload until after prompt encode to reduce peak VRAM"
1338 );
1339 }
1340
1341 let start = Instant::now();
1342 let seed = req.seed.unwrap_or_else(rand_seed);
1343
1344 let width = req.width as usize;
1345 let height = req.height as usize;
1346
1347 tracing::info!(
1348 prompt = %req.prompt,
1349 seed, width, height,
1350 steps = req.steps,
1351 "starting Flux.2 generation"
1352 );
1353
1354 let txt_emb = {
1356 let progress = &self.base.progress;
1357 let loaded = self
1358 .base
1359 .loaded
1360 .as_mut()
1361 .ok_or_else(|| anyhow::anyhow!("model not loaded — call load() first"))?;
1362 let cache_key = prompt_text_key(&req.prompt);
1363 if let Some(tensor) =
1364 restore_cached_tensor(&self.prompt_cache, &cache_key, &loaded.device, loaded.dtype)?
1365 {
1366 progress.cache_hit("prompt conditioning");
1367 tensor
1368 } else {
1369 if loaded.text_encoder.model.is_none() {
1372 let label = if loaded.text_encoder.is_parked() {
1373 "Unparking Qwen3 encoder (CPU→GPU)"
1374 } else {
1375 "Reloading Qwen3 encoder"
1376 };
1377 progress.stage_start(label);
1378 let reload_start = Instant::now();
1379 if loaded.text_encoder.is_parked() {
1380 loaded.text_encoder.unpark_to_gpu(progress)?;
1381 } else {
1382 loaded.text_encoder.reload(progress)?;
1383 }
1384 progress.stage_done(label, reload_start.elapsed());
1385 }
1386
1387 let txt_emb = Self::encode_prompt_cached(
1388 progress,
1389 &self.prompt_cache,
1390 &mut loaded.text_encoder,
1391 &req.prompt,
1392 &loaded.device,
1393 loaded.dtype,
1394 )?;
1395 tracing::info!("Qwen3 encoding complete");
1396
1397 if loaded.text_encoder.on_gpu || loaded.device.is_metal() {
1402 let park_mode = crate::device::keep_te_in_ram()
1403 && !loaded.device.is_metal()
1404 && !loaded.text_encoder.is_quantized;
1405 if park_mode {
1406 loaded.text_encoder.park_to_cpu()?;
1407 tracing::info!(
1408 on_gpu = loaded.text_encoder.on_gpu,
1409 "Qwen3 encoder parked to CPU host RAM"
1410 );
1411 } else {
1412 loaded.text_encoder.drop_weights();
1413 tracing::info!(
1414 on_gpu = loaded.text_encoder.on_gpu,
1415 "Qwen3 encoder dropped to free memory for denoising"
1416 );
1417 }
1418 }
1419
1420 txt_emb
1421 }
1422 };
1423
1424 self.reload_transformer_if_needed()?;
1425
1426 let loaded = self
1427 .base
1428 .loaded
1429 .as_mut()
1430 .ok_or_else(|| anyhow::anyhow!("model not loaded — call load() first"))?;
1431 let progress = &self.base.progress;
1432
1433 let latent_h = height.div_ceil(8);
1435 let latent_w = width.div_ceil(8);
1436
1437 let image_seq_len = (height / 16) * (width / 16);
1439 let mut timesteps = sampling::get_schedule(req.steps as usize, image_seq_len);
1440
1441 if req.source_image.is_some() {
1442 let (trimmed, start_index) =
1443 crate::img2img::trim_schedule_tail(×teps, req.steps as usize, req.strength);
1444 timesteps = trimmed;
1445 tracing::info!(
1446 strength = req.strength,
1447 start_index,
1448 start_timestep = timesteps[0],
1449 schedule = ?timesteps,
1450 remaining_steps = timesteps.len().saturating_sub(1),
1451 "img2img: truncated schedule from strength"
1452 );
1453 }
1454
1455 let (img, inpaint_ctx) = if let Some(ref source_bytes) = req.source_image {
1457 let start_t = timesteps[0];
1458
1459 progress.stage_start("Encoding source image (VAE)");
1460 let encode_start = Instant::now();
1461 let source_tensor = crate::img_utils::decode_source_image(
1462 source_bytes,
1463 req.width,
1464 req.height,
1465 Self::img2img_source_normalize_range(),
1466 &loaded.device,
1467 loaded.vae_dtype,
1468 )?;
1469 let encoded = loaded.vae.encode(&source_tensor)?;
1470 progress.stage_done("Encoding source image (VAE)", encode_start.elapsed());
1471
1472 let prepared = crate::img2img::prepare_flow_match_img2img(
1473 &encoded,
1474 seed,
1475 &[1, 32, latent_h, latent_w],
1476 start_t,
1477 req.mask_image.as_deref(),
1478 latent_h,
1479 latent_w,
1480 &loaded.device,
1481 loaded.dtype,
1482 )?;
1483 (prepared.initial_latents, prepared.inpaint_ctx)
1484 } else {
1485 let img = crate::engine::seeded_randn(
1486 seed,
1487 &[1, 32, latent_h, latent_w],
1488 &loaded.device,
1489 loaded.dtype,
1490 )?;
1491 (img, None)
1492 };
1493
1494 let state = Flux2State::new(&txt_emb, &img)?;
1496 let inpaint_ctx = inpaint_ctx
1497 .as_ref()
1498 .map(crate::img2img::pack_flux_inpaint_context)
1499 .transpose()?;
1500
1501 let denoise_label = format!("Denoising ({} steps)", timesteps.len().saturating_sub(1));
1502 progress.stage_start(&denoise_label);
1503 let denoise_start = Instant::now();
1504 tracing::info!(
1505 steps = timesteps.len().saturating_sub(1),
1506 "running denoising loop..."
1507 );
1508
1509 let transformer = loaded
1511 .transformer
1512 .as_ref()
1513 .ok_or_else(|| anyhow::anyhow!("transformer not loaded"))?;
1514 let img = transformer.denoise(
1515 &state.img,
1516 &state.img_ids,
1517 &state.txt,
1518 &state.txt_ids,
1519 &state.vec,
1520 ×teps,
1521 req.guidance,
1522 progress,
1523 inpaint_ctx.as_ref(),
1524 )?;
1525
1526 let img = sampling::unpack(&img, height, width)?;
1528 progress.stage_done(&denoise_label, denoise_start.elapsed());
1529 tracing::info!("denoising complete, decoding VAE...");
1530
1531 drop(inpaint_ctx);
1535 drop(state);
1536 drop(txt_emb);
1537 loaded.transformer = None;
1538 loaded.device.synchronize()?;
1542 tracing::info!("Transformer dropped to free VRAM for VAE decode");
1543
1544 progress.stage_start("VAE decode");
1546 let vae_decode_start = Instant::now();
1547 if let Ok(dump_path) = std::env::var("MOLD_FLUX2_DUMP_LATENT") {
1549 let latent_f32 = img
1550 .to_dtype(DType::F32)?
1551 .to_device(&candle_core::Device::Cpu)?;
1552 let dims = latent_f32.dims().to_vec();
1553 let v: Vec<f32> = latent_f32.flatten_all()?.to_vec1()?;
1554 let mut bytes = Vec::with_capacity(8 * 4 + v.len() * 4);
1555 bytes.extend_from_slice(&(dims.len() as u32).to_le_bytes());
1556 for d in &dims {
1557 bytes.extend_from_slice(&(*d as u32).to_le_bytes());
1558 }
1559 for x in &v {
1560 bytes.extend_from_slice(&x.to_le_bytes());
1561 }
1562 std::fs::write(&dump_path, &bytes)?;
1563 tracing::info!(path = %dump_path, dims = ?dims, "dumped pre-VAE latent (parallel)");
1564 }
1565 let img_for_vae = img.to_dtype(loaded.vae_dtype)?;
1566 let vae = &loaded.vae;
1567 let device_for_sync = loaded.device.clone();
1568 let img = crate::vae_tiling::decode_with_oom_fallback(
1569 &img_for_vae,
1570 |latents| vae.decode(latents).map_err(Into::into),
1571 || {
1572 if let Err(e) = device_for_sync.synchronize() {
1573 tracing::warn!(
1574 "FLUX2 (parallel) device.synchronize() after VAE OOM failed: {e}"
1575 );
1576 }
1577 },
1578 )?;
1579
1580 let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(DType::U8)?;
1582 let img = img.i(0)?; progress.stage_done("VAE decode", vae_decode_start.elapsed());
1585 tracing::info!("VAE decode complete, encoding output image...");
1586
1587 let output_metadata = build_output_metadata(req, seed, None);
1589 let image_bytes = encode_image(
1590 &img,
1591 req.resolved_output_format(),
1592 req.width,
1593 req.height,
1594 output_metadata.as_ref(),
1595 )?;
1596
1597 let generation_time_ms = start.elapsed().as_millis() as u64;
1598 tracing::info!(generation_time_ms, seed, "generation complete");
1599
1600 Ok(GenerateResponse {
1601 images: vec![ImageData {
1602 data: image_bytes,
1603 format: req.resolved_output_format(),
1604 width: req.width,
1605 height: req.height,
1606 index: 0,
1607 }],
1608 generation_time_ms,
1609 model: req.model.clone(),
1610 seed_used: seed,
1611 video: None,
1612 gpu: None,
1613 })
1614 }
1615}
1616
1617impl InferenceEngine for Flux2Engine {
1618 fn generate(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1619 self.pending_placement = req.placement.clone();
1620 self.pending_loras = effective_flux2_loras(req);
1621 let result = self.generate_inner(req);
1622 self.pending_placement = None;
1623 self.pending_loras.clear();
1624 result
1625 }
1626
1627 fn model_name(&self) -> &str {
1628 self.base.model_name()
1629 }
1630
1631 fn is_loaded(&self) -> bool {
1632 self.base.is_loaded()
1633 }
1634
1635 fn load(&mut self) -> Result<()> {
1636 Flux2Engine::load(self)
1637 }
1638
1639 fn unload(&mut self) {
1640 self.base.unload();
1641 clear_cache(&self.prompt_cache);
1642 }
1643
1644 fn set_on_progress(&mut self, callback: ProgressCallback) {
1645 self.base.set_on_progress(callback);
1646 }
1647
1648 fn clear_on_progress(&mut self) {
1649 self.base.clear_on_progress();
1650 }
1651
1652 fn model_paths(&self) -> Option<&mold_core::ModelPaths> {
1653 Some(&self.base.paths)
1654 }
1655}
1656
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encoders::variant_resolution::Qwen3Size;
    use crate::engine::LoadStrategy;
    use crate::shared_pool::SharedPool;
    use mold_core::ModelPaths;
    use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
    use std::collections::HashMap;
    use std::fs;
    use std::path::{Path, PathBuf};
    use std::sync::{Arc, Mutex};
    use std::time::{SystemTime, UNIX_EPOCH};
    use tokenizers::models::bpe::BPE;

    /// Creates a fresh directory under the system temp dir.
    ///
    /// The name combines `prefix`, the process id and the current time in
    /// nanoseconds to keep separate tests and runs apart.
    fn temp_test_dir(prefix: &str) -> PathBuf {
        let suffix = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let dir = std::env::temp_dir().join(format!("{prefix}-{}-{suffix}", std::process::id()));
        fs::create_dir_all(&dir).unwrap();
        dir
    }

    /// Writes a small placeholder file `name` into `dir` and returns its path.
    fn touch(dir: &Path, name: &str) -> PathBuf {
        let path = dir.join(name);
        fs::write(&path, b"test").unwrap();
        path
    }

    /// Builds `ModelPaths` rooted at `dir`: the transformer is
    /// `dir/<transformer_name>`, the VAE is `dir/vae.safetensors`, the text
    /// tokenizer is `dir/tokenizer.json`, and every other optional component
    /// is `None`. No files are created.
    fn flux2_model_paths(
        dir: &Path,
        transformer_name: &str,
        text_encoder_files: Vec<PathBuf>,
        t5_encoder: Option<PathBuf>,
    ) -> ModelPaths {
        ModelPaths {
            transformer: dir.join(transformer_name),
            transformer_shards: vec![],
            vae: dir.join("vae.safetensors"),
            spatial_upscaler: None,
            temporal_upscaler: None,
            distilled_lora: None,
            t5_encoder,
            clip_encoder: None,
            t5_tokenizer: None,
            clip_tokenizer: None,
            clip_encoder_2: None,
            clip_tokenizer_2: None,
            text_encoder_files,
            text_tokenizer: Some(dir.join("tokenizer.json")),
            decoder: None,
        }
    }

    #[test]
    fn flux2_img2img_uses_minus_one_to_one_source_normalization() {
        assert_eq!(
            Flux2Engine::img2img_source_normalize_range(),
            crate::img_utils::NormalizeRange::MinusOneToOne
        );
    }

    #[test]
    fn sequential_img2img_encodes_source_before_transformer_load() {
        assert!(
            Flux2Engine::sequential_img2img_preencodes_source(),
            "sequential Flux.2 img2img must not keep the VAE resident while loading the transformer"
        );
    }

    #[test]
    fn eager_warm_request_delays_transformer_reload_until_after_prompt_encode() {
        assert!(
            Flux2Engine::should_delay_transformer_reload_for_prompt_encode(
                LoadStrategy::Eager,
                false
            ),
            "warm eager requests with a dropped transformer must encode/drop Qwen3 before reload"
        );
        assert!(
            !Flux2Engine::should_delay_transformer_reload_for_prompt_encode(
                LoadStrategy::Eager,
                true
            ),
            "fully loaded eager requests should keep the existing hot path"
        );
        assert!(
            !Flux2Engine::should_delay_transformer_reload_for_prompt_encode(
                LoadStrategy::Sequential,
                false
            ),
            "sequential mode already handles load-use-drop ordering"
        );
    }

    #[test]
    fn flux2_model_name_controls_transformer_and_encoder_config() {
        let base_dir = temp_test_dir("mold-flux2-config");
        let standard = Flux2Engine::new(
            "flux2-klein:q8".to_string(),
            flux2_model_paths(&base_dir, "transformer.gguf", vec![], None),
            None,
            LoadStrategy::Sequential,
            0,
            false,
            None,
        );
        let nine_b = Flux2Engine::new(
            "flux2-klein-9b:q8".to_string(),
            flux2_model_paths(&base_dir, "transformer.gguf", vec![], None),
            None,
            LoadStrategy::Sequential,
            0,
            false,
            None,
        );

        let standard_cfg = standard.resolve_config();
        let nine_b_cfg = nine_b.resolve_config();

        assert_eq!(standard_cfg.hidden_size, 3072);
        assert_eq!(standard_cfg.context_in_dim, 7680);
        assert_eq!(standard.qwen3_size(), Qwen3Size::B4);
        assert_eq!(standard.qwen3_bf16_config().hidden_size, 2560);

        assert_eq!(nine_b_cfg.hidden_size, 4096);
        assert_eq!(nine_b_cfg.context_in_dim, 12288);
        assert_eq!(nine_b.qwen3_size(), Qwen3Size::B8);
        assert_eq!(nine_b.qwen3_bf16_config().hidden_size, 4096);

        fs::remove_dir_all(base_dir).ok();
    }

    #[test]
    fn flux2_text_encoder_paths_use_shards_or_t5_fallback() {
        let dir = temp_test_dir("mold-flux2-paths");
        let shard_a = touch(&dir, "encoder-1.safetensors");
        let shard_b = touch(&dir, "encoder-2.safetensors");
        let fallback = touch(&dir, "encoder.safetensors");

        // Explicit shards win over the t5_encoder fallback.
        let sharded = Flux2Engine::new(
            "flux2-klein:q8".to_string(),
            flux2_model_paths(
                &dir,
                "transformer.gguf",
                vec![shard_a.clone(), shard_b.clone()],
                Some(fallback.clone()),
            ),
            None,
            LoadStrategy::Sequential,
            0,
            false,
            None,
        );
        assert_eq!(sharded.text_encoder_paths(), vec![shard_a, shard_b]);

        // With no shards, the t5_encoder path is used.
        let fallback_engine = Flux2Engine::new(
            "flux2-klein:q8".to_string(),
            flux2_model_paths(&dir, "transformer.gguf", vec![], Some(fallback.clone())),
            None,
            LoadStrategy::Sequential,
            0,
            false,
            None,
        );
        assert_eq!(fallback_engine.text_encoder_paths(), vec![fallback]);

        fs::remove_dir_all(dir).ok();
    }

    #[test]
    fn flux2_loads_qwen3_tokenizer_through_shared_pool() {
        let dir = temp_test_dir("mold-flux2-tokenizer-pool");
        let tokenizer_path = dir.join("tokenizer.json");
        tokenizers::Tokenizer::new(BPE::default())
            .save(&tokenizer_path, false)
            .unwrap();

        let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
        let pooled = shared_pool
            .lock()
            .unwrap()
            .load_tokenizer(&tokenizer_path)
            .unwrap();

        let engine = Flux2Engine::new(
            "flux2-klein:q8".to_string(),
            flux2_model_paths(&dir, "transformer.gguf", vec![], None),
            None,
            LoadStrategy::Sequential,
            0,
            false,
            Some(shared_pool),
        );

        let loaded = engine.load_text_tokenizer(&tokenizer_path).unwrap();

        // The engine must hand back the pool's instance, not a fresh copy.
        assert!(Arc::ptr_eq(&pooled, &loaded));
        fs::remove_dir_all(dir).ok();
    }

    #[test]
    fn flux2_forced_offload_uses_sequential_generation_path() {
        let dir = temp_test_dir("mold-flux2-offload-sequential");
        let engine = Flux2Engine::new(
            "flux2-klein:bf16".to_string(),
            flux2_model_paths(&dir, "transformer.safetensors", vec![], None),
            None,
            LoadStrategy::Eager,
            0,
            true,
            None,
        );

        assert!(
            engine.uses_sequential_generate_path(),
            "Flux.2 --offload requests must reach the engine and select the \
             staged generation path instead of being silently ignored"
        );

        fs::remove_dir_all(dir).ok();
    }

    #[test]
    fn flux2_offload_decision_gates_current_unsupported_cases() {
        assert_eq!(
            flux2_offload_decision(false, false, false),
            Flux2OffloadDecision::Disabled
        );
        assert_eq!(
            flux2_offload_decision(true, false, false),
            Flux2OffloadDecision::Selected
        );
        assert!(matches!(
            flux2_offload_decision(true, true, false),
            Flux2OffloadDecision::Unsupported(reason)
                if reason.contains("GGUF variants")
        ));
        assert!(matches!(
            flux2_offload_decision(true, false, true),
            Flux2OffloadDecision::Unsupported(reason)
                if reason.contains("LoRA")
        ));
    }

    #[test]
    fn flux2_selected_bf16_offload_reaches_runtime_loader() {
        let dir = temp_test_dir("mold-flux2-offload-loader");
        // Placeholder files at the locations `flux2_model_paths` points to.
        touch(&dir, "transformer.safetensors");
        touch(&dir, "vae.safetensors");
        touch(&dir, "tokenizer.json");
        let encoder = touch(&dir, "encoder.safetensors");
        let mut engine = Flux2Engine::new(
            "flux2-klein:bf16".to_string(),
            flux2_model_paths(&dir, "transformer.safetensors", vec![encoder], None),
            None,
            LoadStrategy::Sequential,
            0,
            true,
            None,
        );
        // Pre-seed the prompt cache with an embedding of the right width,
        // presumably so generation does not need to run Qwen3 on the
        // placeholder encoder file.
        let cfg = engine.resolve_config();
        let txt_emb = Tensor::zeros((1, 1, cfg.context_in_dim), DType::F32, &Device::Cpu).unwrap();
        engine.prompt_cache.lock().unwrap().insert(
            prompt_text_key("a cat"),
            CachedTensor::from_tensor(&txt_emb).unwrap(),
        );
        let req = GenerateRequest {
            prompt: "a cat".to_string(),
            negative_prompt: None,
            model: "flux2-klein:bf16".to_string(),
            width: 64,
            height: 64,
            steps: 1,
            guidance: 0.0,
            seed: Some(1),
            batch_size: 1,
            output_format: None,
            embed_metadata: None,
            scheduler: None,
            cfg_plus: None,
            source_image: None,
            edit_images: None,
            strength: 1.0,
            mask_image: None,
            control_image: None,
            control_model: None,
            control_scale: 1.0,
            expand: None,
            original_prompt: None,
            lora: None,
            frames: None,
            fps: None,
            upscale_model: None,
            gif_preview: false,
            enable_audio: None,
            audio_file: None,
            audio_file_path: None,
            source_video: None,
            source_video_path: None,
            keyframes: None,
            pipeline: None,
            loras: None,
            retake_range: None,
            spatial_upscale: None,
            temporal_upscale: None,
            placement: Some(mold_core::types::DevicePlacement {
                text_encoders: mold_core::types::DeviceRef::Cpu,
                advanced: Some(mold_core::types::AdvancedPlacement {
                    transformer: mold_core::types::DeviceRef::Cpu,
                    vae: mold_core::types::DeviceRef::Cpu,
                    ..Default::default()
                }),
            }),
        };

        // The placeholder weights make generation fail; what matters is
        // which error is produced.
        let err = engine.generate_sequential(&req).unwrap_err().to_string();

        assert!(
            !err.contains("streaming is not implemented yet"),
            "selected BF16 offload must reach the runtime loader, got: {err}"
        );
        fs::remove_dir_all(dir).ok();
    }

    #[test]
    fn flux2_loads_vae_tensors_through_shared_pool() {
        let dir = temp_test_dir("mold-flux2-vae-pool");
        let vae_path = dir.join("vae.safetensors");
        let weight = 1.0f32.to_le_bytes();
        let mut tensors = HashMap::new();
        tensors.insert(
            "encoder.conv_in.weight".to_string(),
            TensorView::new(SafeDtype::F32, vec![1], &weight).unwrap(),
        );
        serialize_to_file(&tensors, &None, &vae_path).unwrap();

        let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
        let pooled = shared_pool
            .lock()
            .unwrap()
            .load_safetensors_cpu_tensors(std::slice::from_ref(&vae_path))
            .unwrap()
            .unwrap();

        let engine = Flux2Engine::new(
            "flux2-klein:q8".to_string(),
            flux2_model_paths(&dir, "transformer.gguf", vec![], None),
            None,
            LoadStrategy::Sequential,
            0,
            false,
            Some(shared_pool),
        );

        let loaded = engine.load_vae_cpu_tensors().unwrap().unwrap();

        // The engine must hand back the pool's tensors, not a fresh load.
        assert!(Arc::ptr_eq(&pooled, &loaded));
        fs::remove_dir_all(dir).ok();
    }

    #[test]
    fn flux2_validate_paths_accepts_existing_files_and_returns_tokenizer() {
        let dir = temp_test_dir("mold-flux2-validate-ok");
        touch(&dir, "transformer.gguf");
        touch(&dir, "vae.safetensors");
        let encoder = touch(&dir, "encoder.safetensors");
        let tokenizer = touch(&dir, "tokenizer.json");

        let engine = Flux2Engine::new(
            "flux2-klein:q8".to_string(),
            flux2_model_paths(&dir, "transformer.gguf", vec![encoder], None),
            None,
            LoadStrategy::Sequential,
            0,
            false,
            None,
        );

        assert_eq!(engine.validate_paths().unwrap(), tokenizer);
        assert!(engine.is_gguf_transformer());

        fs::remove_dir_all(dir).ok();
    }

    #[test]
    fn flux2_validate_paths_requires_text_encoder_paths() {
        let dir = temp_test_dir("mold-flux2-validate-missing");
        touch(&dir, "transformer.safetensors");
        touch(&dir, "vae.safetensors");
        touch(&dir, "tokenizer.json");

        // No text encoder shards and no t5_encoder fallback.
        let engine = Flux2Engine::new(
            "flux2-klein:bf16".to_string(),
            flux2_model_paths(&dir, "transformer.safetensors", vec![], None),
            None,
            LoadStrategy::Sequential,
            0,
            false,
            None,
        );

        let err = engine.validate_paths().unwrap_err();
        assert!(err.to_string().contains("text encoder paths required"));
        assert!(!engine.is_gguf_transformer());

        fs::remove_dir_all(dir).ok();
    }
}